NLP Deep Learning Training on Downstream tasks using Pytorch Lightning — Multiple Choice on SWAG data — Part 4 of 7

Narayana Swamy
5 min read · Jul 23, 2021

This is Part 4 of the Series. Please see the Intro article for the motivation behind this Series. Here we walk through the sections of the Colab Notebook that trains a multiple-choice model on the public SWAG data, with comments on each section.

  • Download and Import the Libraries — The torchmetrics library is installed and used to calculate the accuracy of the predictions. Other than that, the regular PyTorch and PyTorch Lightning libraries are installed and imported.
  • Download the Data — The SWAG dataset is available pre-processed from the Hugging Face Datasets library, so we load it from there instead of pulling the raw data from the original public repository. There are 73,546 training samples, 20,006 validation samples and 20,005 test samples in the dataset. Each row of data looks like this:
{'ending0': 'are playing ping pong and celebrating one left each in quick.',  
'ending1': 'wait slowly towards the cadets.',
'ending2': 'continues to play as well along the crowd along with the band being interviewed.',
'ending3': 'continue to play marching, interspersed.',
'fold-ind': '3417',
'gold-source': 'gen',
'label': 3,
'sent1': 'A drum line passes by walking down the street playing their instruments.',
'sent2': 'Members of the procession',
'startphrase': 'A drum line passes by walking down the street playing their instruments. Members of the procession',
'video-id': 'anetv_jkn6uvmqwh4'}

Each example has four possible endings for the sent2 text: sent1 is the opening (context) sentence, sent2 is the start of the next sentence, and label is the index of the correct ending. The data can be visualized as a multiple-choice question like this:

Context: A drum line passes by walking down the street playing their instruments.   
A - Members of the procession are playing ping pong and celebrating one left each in quick.
B - Members of the procession wait slowly towards the cadets.
C - Members of the procession continues to play as well along the crowd along with the band being interviewed.
D - Members of the procession continue to play marching, interspersed.
Ground truth: option D

Given sent1 and the four choices, the trained model must pick the choice that most plausibly follows the context.
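To make the mapping from a SWAG row to the question above concrete, here is a minimal sketch that rebuilds the multiple-choice view from the row's fields. The helper name `format_as_question` is hypothetical and not part of the notebook; it only assumes the field names shown in the row above.

```python
# Letters used to label the four SWAG endings (ending0 ... ending3).
CHOICE_LETTERS = ["A", "B", "C", "D"]


def format_as_question(row: dict) -> str:
    """Render one SWAG row as a multiple-choice question (hypothetical helper)."""
    lines = [f"Context: {row['sent1']}"]
    for i, letter in enumerate(CHOICE_LETTERS):
        # Each option is sent2 followed by one of the four candidate endings.
        lines.append(f"{letter} - {row['sent2']} {row[f'ending{i}']}")
    # The label is the index of the correct ending.
    lines.append(f"Ground truth: option {CHOICE_LETTERS[row['label']]}")
    return "\n".join(lines)


row = {
    "ending0": "are playing ping pong and celebrating one left each in quick.",
    "ending1": "wait slowly towards the cadets.",
    "ending2": "continues to play as well along the crowd along with the band being interviewed.",
    "ending3": "continue to play marching, interspersed.",
    "label": 3,
    "sent1": "A drum line passes by walking down the street playing their instruments.",
    "sent2": "Members of the procession",
}
print(format_as_question(row))
```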

  • Define the Pre-Trained Model — The pre-trained model used here is distilbert-base-uncased. DistilBERT is a distilled version of BERT that is 60% faster and 40% smaller while retaining about 97% of BERT's performance. Once training works with this model, other pre-trained models can be tried by changing the model_checkpoint variable.
  • Define the Pre-Process Function or Dataset Class — Here we define the pre-process function that creates the train, validation and test data in the Dataset format needed by the DataLoader (PyTorch uses the DataLoader class to assemble the data into mini-batches). The data is tokenized in this function using the pre-trained tokenizer: the context (sent1) is repeated four times as the first sentence, and each copy is tokenized together with one of the four second sentences (sent2 followed by one of the endings).
  • Define the DataModule Class — This is a PyTorch Lightning class that contains all the code needed to prepare mini-batches of the data with DataLoaders. At the start of training, the Trainer calls the prepare_data and setup methods first. A collate function here pads the mini-batches, because BERT-like models require all inputs in a mini-batch to have the same length. Instead of padding every input to the longest sequence in the entire dataset, the collate function pads only to the longest sequence within that mini-batch, which makes training faster and uses less memory. The labels key is removed before the tokenizer is applied and then added back to the batch at the end.
  • It helps to spell out the batch dimensions. Say the batch size is 16 and the longest tokenized sequence in the batch is 50 tokens. Then batch['input_ids'] (and likewise batch['attention_mask']) has dimensions 16x4x50, where the 4 represents the four multiple choices, and batch['labels'] is a 1-D tensor holding one label per example (16 values).
  • Define the Model Class — The forward function of the DL model is defined here. The 'input_ids' and 'attention_mask' in the batch are first flattened from 16x4x50 to 64x50, so that each sentence combination (sentence 1 paired with one of the four second sentences) is sent through the transformer model as a separate sequence. The CLS-token output from the last hidden layer (shape 64x768, where 768 is the hidden size) is passed through a dense Linear layer, a ReLU, a Dropout layer and finally a Linear layer that outputs one score per combination (64x1). The output is reshaped to 16x4, giving four scores (logits) per question. These logits are compared with the target labels for the loss calculation, and the argmax over the four choices gives the predicted answer. The model is essentially trained to give the highest probability to the sentence 1 + sentence 2 combination that is most likely the correct answer.
  • Define the PyTorch Lightning Module Class — This is where the training, validation and test step functions are defined; the model loss and accuracy are calculated in these step functions. The optimizers and schedulers are defined here as well. Nothing special to note here: plain CrossEntropyLoss for the loss and the torchmetrics package for accuracy.
  • Define the Trainer Parameters — All the required Trainer parameters and callbacks are defined here. We define three callbacks: EarlyStopping, LearningRateMonitor and ModelCheckpoint. Instead of using argparse to define the parameters, recent PyTorch Lightning releases (via LightningCLI) allow the parameters to be defined in a .yaml file, which can be passed as an argument to a Python .py script run from the command line. This keeps the Trainer parameters separate from the training code. Since we are using a Colab Notebook for demo purposes, we stick with argparse.
  • Train the Model — This is done with the Trainer.fit() method. A profiler can be set in the Trainer parameters to report where time is spent during the training run.
  • Evaluate Model Performance — The test samples in this dataset have no labels, so test metrics could not be computed. After 3 epochs of training, we get the following validation metrics:
VALIDATE RESULTS 
{'val_accuracy': 0.6988,
'val_loss': 0.7514}
  • The validation accuracy of 69.88% is well below the SOTA score of 86.6% reported on the Dev data (assuming the Dev data corresponds to the validation data here), which was achieved with a BERT Large model. BERT Large cannot be fine-tuned on the free version of Colab because it runs out of memory, but it should be possible to reproduce the SOTA result on a machine with a larger GPU.
  • Run Inference on the Trained Model — Send a sample batch of text to the trained model to get a prediction. This can serve as the basis for an ML inference pipeline.
  • TensorBoard Logs Data — This opens TensorBoard within the Colab notebook so you can browse the various training logs. PyTorch Lightning logs to TensorBoard by default; a different logger can be passed to the Trainer through its logger parameter.
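The pre-process step described in the walkthrough (repeat the context four times, pair it with sent2 plus each ending, tokenize, then regroup) can be sketched as below. This follows the common Hugging Face multiple-choice pattern and is not necessarily the notebook's exact code; the function name `preprocess_function` and the `tokenizer` argument are assumptions. In practice `tokenizer` would be something like `AutoTokenizer.from_pretrained("distilbert-base-uncased")`, and the function would be applied with the Datasets `map(..., batched=True)` call.

```python
from typing import Callable, Dict, List

ENDING_NAMES = ["ending0", "ending1", "ending2", "ending3"]


def preprocess_function(examples: Dict[str, List], tokenizer: Callable) -> Dict[str, List]:
    """Tokenize a column-oriented batch of SWAG rows into 4 (context, option) pairs per row."""
    # Repeat each context (sent1) four times, once per candidate ending.
    first_sentences = [[context] * 4 for context in examples["sent1"]]
    # Pair each context copy with sent2 followed by one of the endings.
    second_sentences = [
        [f"{header} {examples[end][i]}" for end in ENDING_NAMES]
        for i, header in enumerate(examples["sent2"])
    ]
    # Flatten to 4*N pairs so the tokenizer sees plain lists of strings.
    first_flat = [s for group in first_sentences for s in group]
    second_flat = [s for group in second_sentences for s in group]
    tokenized = tokenizer(first_flat, second_flat, truncation=True)
    # Regroup every tokenized field into chunks of 4 (one chunk per original example).
    return {
        key: [values[i : i + 4] for i in range(0, len(values), 4)]
        for key, values in tokenized.items()
    }
```

Padding is deliberately left out here; it happens later, per mini-batch, in the collate function.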
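The per-mini-batch padding and the 16x4x50 batch layout discussed in the DataModule bullets can be illustrated with a small collate function. This is a sketch under stated assumptions: the name `multiple_choice_collate` is hypothetical, each feature is assumed to hold `input_ids` as four token-id lists plus a `label`, and padding is done manually with `pad_token_id` (assumed to be 0, the [PAD] id in the BERT uncased vocabulary) rather than through the tokenizer, to keep the example self-contained.

```python
from typing import Dict, List

import torch


def multiple_choice_collate(features: List[Dict], pad_token_id: int = 0) -> Dict[str, torch.Tensor]:
    """Pad multiple-choice features only to the longest sequence in this mini-batch."""
    # Keep the labels aside so only the token fields are padded; they are added back at the end.
    labels = [feature["label"] for feature in features]
    batch_size = len(features)
    num_choices = len(features[0]["input_ids"])
    # Flatten (batch, choices) into one list of sequences.
    flat_ids = [ids for feature in features for ids in feature["input_ids"]]
    max_len = max(len(ids) for ids in flat_ids)  # longest in this mini-batch, not the dataset
    input_ids = torch.full((len(flat_ids), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(flat_ids), max_len), dtype=torch.long)
    for row, ids in enumerate(flat_ids):
        input_ids[row, : len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[row, : len(ids)] = 1  # 1 marks real tokens, 0 marks padding
    return {
        # Restore the (batch, choices, seq_len) layout, e.g. 16x4x50.
        "input_ids": input_ids.view(batch_size, num_choices, max_len),
        "attention_mask": attention_mask.view(batch_size, num_choices, max_len),
        # One class index per example, shape (batch,), as CrossEntropyLoss expects.
        "labels": torch.tensor(labels, dtype=torch.long),
    }
```

It would be plugged in as `DataLoader(dataset, batch_size=16, collate_fn=multiple_choice_collate)`.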
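The model forward pass described above (flatten 16x4x50 to 64x50, take the CLS vector, run a Linear, ReLU, Dropout, Linear head, reshape to 16x4) can be sketched as follows. The class name, the dropout rate of 0.1 and the `encoder` argument are assumptions; `encoder` is meant to be something like `AutoModel.from_pretrained("distilbert-base-uncased")`, whose output exposes `last_hidden_state`.

```python
import torch
from torch import nn


class MultipleChoiceModel(nn.Module):
    """Transformer encoder plus a small head that scores each (context, option) pair."""

    def __init__(self, encoder: nn.Module, hidden_size: int = 768, dropout: float = 0.1):
        super().__init__()
        self.encoder = encoder  # assumed to return an object with .last_hidden_state
        self.head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),  # dense layer
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),  # one score per pair
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        batch_size, num_choices, seq_len = input_ids.shape  # e.g. 16 x 4 x 50
        # Fold the choices into the batch dimension: 16x4x50 -> 64x50.
        flat_ids = input_ids.reshape(-1, seq_len)
        flat_mask = attention_mask.reshape(-1, seq_len)
        hidden = self.encoder(input_ids=flat_ids, attention_mask=flat_mask).last_hidden_state
        cls = hidden[:, 0]  # CLS token of the last hidden layer: 64 x hidden_size
        scores = self.head(cls)  # 64 x 1
        return scores.view(batch_size, num_choices)  # 16 x 4 logits, one per choice
```

The returned logits go straight into the cross-entropy loss; the argmax is only taken for accuracy and prediction.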
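The step functions in the Lightning module boil down to a cross-entropy loss on the 16x4 logits and an accuracy on their argmax. The sketch below shows that computation with plain tensor operations; the helper name `loss_and_accuracy` is hypothetical, and the simple accuracy line stands in for the torchmetrics Accuracy object used in the notebook. Inside a `validation_step`, the results would typically be logged with `self.log("val_loss", loss)` and `self.log("val_accuracy", acc)`.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def loss_and_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cross-entropy on (batch, num_choices) logits plus accuracy of the argmax prediction."""
    # The loss uses the raw logits, so gradients flow to all four choice scores.
    loss = F.cross_entropy(logits, labels)
    # The argmax over the choices is used only to measure accuracy.
    preds = logits.argmax(dim=-1)
    accuracy = (preds == labels).float().mean()
    return loss, accuracy
```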
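For the inference step, a single question can be turned into a 1x4xL batch and scored as sketched below. The function name `predict_choice` is hypothetical, and it assumes a model whose forward takes `input_ids` and `attention_mask` of shape (batch, choices, seq_len) and returns (batch, choices) logits, as in the model description above. The tokenizer call uses the standard Hugging Face `padding`, `truncation` and `return_tensors` arguments.

```python
from typing import Callable, List

import torch


@torch.no_grad()
def predict_choice(model: torch.nn.Module, tokenizer: Callable, context: str,
                   start: str, endings: List[str]) -> int:
    """Return the index of the most likely ending for one question."""
    model.eval()  # disable dropout for inference
    first = [context] * len(endings)
    second = [f"{start} {ending}" for ending in endings]
    enc = tokenizer(first, second, padding=True, truncation=True, return_tensors="pt")
    # Add the batch dimension: (num_choices, seq_len) -> (1, num_choices, seq_len).
    input_ids = enc["input_ids"].unsqueeze(0)
    attention_mask = enc["attention_mask"].unsqueeze(0)
    logits = model(input_ids=input_ids, attention_mask=attention_mask)  # (1, num_choices)
    return int(logits.argmax(dim=-1).item())
```

For the drum-line example, `context` would be sent1, `start` would be sent2, and `endings` the four ending strings; an index of 3 would correspond to option D.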

Next, we will look at training for the Question Answering task in Part 5 of this Series.

