BERT-Lite: Classification Using a Transformer in TensorFlow 2.0
TensorFlow 2.0 alpha is here, and it is a much easier, more Pythonic version to work with than 1.x. Google open-sourced pre-trained versions of BERT in November 2018 but has not released a pre-trained version for TF 2.0. If you are new to BERT: it is a state-of-the-art deep learning language model based on the Transformer architecture that has beaten many previous benchmarks in language processing, including GLUE and SQuAD. Here are a few links to important papers and posts on BERT: Attention Is All You Need, BERT, The Illustrated Transformer, The Illustrated BERT, Deconstructing BERT, Deconstructing BERT, Part 2.
I have played around with BERT in PyTorch using Hugging Face's port of the BERT model parameters to PyTorch. I wanted to get my feet wet in TensorFlow 2.0 with all the exciting new features it offers, and there is no better way than trying to build a BERT-like Transformer model in TensorFlow 2.0. The data used for the classification problem is the public IMDb dataset of movie reviews along with their binary sentiment polarity labels. The core dataset contains 50,000 reviews split evenly into 25k train and 25k test sets, and the overall distribution of labels is balanced (25k positive and 25k negative).
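As an aside, pulling the raw dataset is straightforward with TensorFlow Datasets. Below is a minimal sketch; this is not necessarily how the notebook loads the data, and the tokenization and padding steps are omitted.
import tensorflow_datasets as tfds

# 'imdb_reviews' ships 25,000 labelled reviews each for train and test;
# labels are 0 (negative) or 1 (positive).
data, info = tfds.load('imdb_reviews', with_info=True, as_supervised=True)
train_ds, test_ds = data['train'], data['test']

# Peek at one raw (review, label) pair.
for text, label in train_ds.take(1):
    print(text.numpy()[:100], label.numpy())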
This model borrows heavily from the TF 2.0 published example on language understanding using the Transformer architecture. That Colab example describes the encoder and decoder networks in detail. I took the Encoder layer of that model and attached a binary classification layer to it. As suggested in the BERT paper, each sentence is prefixed with a special [CLS] token. For classification, the vector representation of this [CLS] token at the output of the encoder network is fed into a sigmoid output layer. One could argue that it is better to average-pool all the token representations at the end of the encoder, but that is not what Google did, and there may be good reasons why they didn't do so.
# We "pool" the model by simply taking the hidden state
# corresponding to the first token.
enc_output = self.dense1(enc_output[:, 0])
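For context, here is a minimal sketch of how such a classification head can sit on top of the encoder. The Encoder class and its constructor arguments are assumed to match the encoder stack from the TF 2.0 Transformer tutorial, and the class and layer names are illustrative rather than the notebook's exact code.
import tensorflow as tf

class TransformerClassifier(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
                 input_vocab_size, rate=0.1):
        super(TransformerClassifier, self).__init__()
        # Encoder stack (multi-head attention + feed-forward blocks),
        # assumed to follow the TF 2.0 tutorial implementation.
        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, rate)
        # Single sigmoid unit for the binary sentiment label.
        self.dense1 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inp, training, enc_padding_mask):
        # enc_output shape: (batch_size, seq_len, d_model)
        enc_output = self.encoder(inp, training, enc_padding_mask)
        # "Pool" by taking the hidden state of the first ([CLS]) token
        # and map it to a probability of a positive review.
        return self.dense1(enc_output[:, 0])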
Epoch 1 Train_Loss 0.7173 Train_Accuracy 0.5194 Val_Loss 0.7574 Val_Accuracy 0.4900
Time taken for 1 epoch: 155.58440947532654 secs
Epoch 2 Train_Loss 0.6458 Train_Accuracy 0.6164 Val_Loss 0.6568 Val_Accuracy 0.6465
Time taken for 1 epoch: 128.04944705963135 secs
Epoch 3 Train_Loss 0.3961 Train_Accuracy 0.8212 Val_Loss 0.4630 Val_Accuracy 0.8165
Time taken for 1 epoch: 129.15209412574768 secs
Epoch 4 Train_Loss 0.3053 Train_Accuracy 0.8737 Val_Loss 0.3300 Val_Accuracy 0.8696
Time taken for 1 epoch: 128.44164967536926 secs
Saving checkpoint for epoch 5 at ./checkpoints/train/ckpt-1
Epoch 5 Train_Loss 0.2265 Train_Accuracy 0.9182 Val_Loss 0.3873 Val_Accuracy 0.8592
Time taken for 1 epoch: 131.58583331108093 secs
Epoch 6 Train_Loss 0.2029 Train_Accuracy 0.9219 Val_Loss 0.7408 Val_Accuracy 0.8262
Time taken for 1 epoch: 129.55460572242737 secs
Epoch 7 Train_Loss 0.1693 Train_Accuracy 0.9398 Val_Loss 0.5033 Val_Accuracy 0.7846
Time taken for 1 epoch: 130.63074278831482 secs
Epoch 8 Train_Loss 0.2424 Train_Accuracy 0.9277 Val_Loss 0.7738 Val_Accuracy 0.7846
Time taken for 1 epoch: 129.82275438308716 secs
Epoch 9 Train_Loss 0.1899 Train_Accuracy 0.9349 Val_Loss 0.5986 Val_Accuracy 0.8154
Time taken for 1 epoch: 131.50551176071167 secs
Saving checkpoint for epoch 10 at ./checkpoints/train/ckpt-2
Epoch 10 Train_Loss 0.1721 Train_Accuracy 0.9460 Val_Loss 0.5653 Val_Accuracy 0.8315
Time taken for 1 epoch: 135.71319437026978 secs
You can see the complete notebook here. A few things to note:
- The performance shown above is not ideal. After epoch 4, the model starts overfitting, as evidenced by the widening gap between training and validation accuracy; more regularization, such as heavier dropout, will be needed. Also not shown in the results: after about 12 epochs the training loss itself starts increasing, which is more related to the learning rate not being tuned. I have to play around more with the learning rate schedule to get the model to train for more epochs (see the schedule sketch after this list).
- Need to attach TensorBoard to the model. The straightforward route is the Keras model.fit function, where TensorBoard can be attached as a callback, but I had trouble feeding the tf.data dataset into model.fit and need to spend more time on it (a rough sketch of the Keras route is included after this list).
- Keep in mind that this model is trained entirely on just the movie-review dataset and is unlikely to perform well on other language datasets. The open-sourced model published by Google was trained on such a huge amount of text that it can be used as a starting point for transfer learning on other language datasets.
- This model uses only 6 encoder layers and a 512-dimensional encoding vector. Even the smaller BERT-Base uncased model uses 12 layers, a 768-dimensional hidden size, 12 attention heads, and 110M parameters.
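On the learning-rate point in the first bullet, one obvious candidate to revisit is the warmup schedule used in the Transformer tutorial, sketched below. The warmup_steps value here is an assumption, not a tuned setting.
import tensorflow as tf

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # Linear warmup followed by inverse-square-root decay, as in
    # "Attention Is All You Need".
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

learning_rate = CustomSchedule(d_model=512)
optimizer = tf.keras.optimizers.Adam(learning_rate,
                                     beta_1=0.9, beta_2=0.98, epsilon=1e-9)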
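For the TensorBoard point in the second bullet, the Keras route would look roughly like the sketch below, assuming the model's call signature has been adapted to what model.fit expects and that train_dataset / val_dataset are tf.data.Dataset objects yielding (token_ids, label) batches. This is the general pattern, not working code from the notebook.
import tensorflow as tf

# Compile with a binary cross-entropy loss to match the sigmoid output.
model.compile(optimizer=optimizer,
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Logs scalars for TensorBoard under ./logs.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir='./logs')

# train_dataset / val_dataset are assumed to yield (token_ids, label) batches.
model.fit(train_dataset,
          validation_data=val_dataset,
          epochs=10,
          callbacks=[tensorboard_cb])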