NLP Deep Learning Training on Downstream tasks using Pytorch Lightning — Question Answering on SQUAD data — Part 5 of 7

Narayana Swamy
Jul 23, 2021

This is Part 5 of the Series. Please see the Intro article for the motivation behind this Series. We will walk through the various sections of the Question Answering training on the public SQuAD data in the Colab Notebook and comment on each section.

  • Download and Import the Libraries — The regular Pytorch, transformers, datasets and Pytorch Lightning libraries are installed and imported.
  • Download the Data — The Stanford Question Answering Dataset (SQuAD) comes in two flavors: SQuAD 1.1 and SQuAD 2.0. These reading comprehension datasets consist of questions posed on a set of Wikipedia articles, where the answer to every question is a segment (or span) of the corresponding passage. In SQuAD 1.1, every question has an answer in the corresponding passage. SQuAD 2.0 steps up the difficulty by including questions that cannot be answered from the provided passage. We will be using the SQuAD 1.1 dataset from the Hugging Face Datasets library, rather than pulling the data from the public repository (https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json). The dataset has 87,599 Training samples and 10,570 Validation samples, and does not contain any Test samples. Each row of data looks like this:
{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},  
'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
'id': '5733be284776f41900661182',
'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
'title': 'University_of_Notre_Dame'}
  • The answer is provided by two keys: ‘answer_start’, the character index in the Context where the Answer starts, and ‘text’, the actual answer text. The Answer is always contained within the Context provided. As you can see, the Context can be rather large and may not fit within the 512-token limit of a BERT model, so the data needs some pre-processing.
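A minimal sketch of how the two answer keys relate to the Context, using a shortened toy row in the same format (the context, question and index below are made up for illustration, not taken from the dataset). The answer's end character is answer_start plus the length of the answer text, and slicing the Context with that range must give back the answer text exactly.

```python
# Toy row in the same format as a SQuAD 1.1 example (illustrative, not real data)
example = {
    "answers": {"answer_start": [17], "text": ["Saint Bernadette"]},
    "context": "Mary appeared to Saint Bernadette in 1858.",
    "question": "To whom did Mary appear in 1858?",
}


def answer_char_span(example: dict) -> tuple:
    """Return (start_char, end_char) of the first answer; end_char is exclusive."""
    start_char = example["answers"]["answer_start"][0]
    end_char = start_char + len(example["answers"]["text"][0])
    return start_char, end_char


start_char, end_char = answer_char_span(example)
# The answer text is an exact substring of the context at that position
assert example["context"][start_char:end_char] == example["answers"]["text"][0]
print(start_char, end_char)  # 17 33
```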
  • Define the Pre-Trained Model — The Pre-Trained Model used here is the DistilBert-Base-uncased Model. It is a distilled version of BERT that is 60% faster, 40% lighter in memory, and still retains 97% of BERT’s performance. Once you have trained successfully with this, other pre-Trained models can be tried by changing the model_checkpoint variable.
  • Define the Pre-Process function or Dataset Class — Here we define the Pre-Process function that creates the train and val data in the Dataset format needed by the DataLoader (Pytorch uses a DataLoader class to build the data into mini-batches). The data is tokenized in this function using the pre-trained tokenizer: the Question and Context are combined and tokenized together, and each tokenized chunk follows this format: [CLS] question tokens [SEP] context tokens [SEP]
  • Since the Context can be rather large (larger than 512 tokens), it is broken up into chunks of at most 384 tokens, counting both question and context tokens. Consecutive chunks overlap by 128 tokens (set by doc_stride) to make sure the Question has access to the sentences in the Context that come before and after an Answer text, even when the Answer sits near a chunk boundary. A long Context therefore produces multiple examples for one question, and the start and end positions of the answer are tracked within each chunk of the broken-up Context. More details and explanation of this process can be found here.
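The chunking can be sketched in plain Python on token positions, without a tokenizer. This is a simplified illustration, not the tokenizer's own overflow logic: window_context and label_window are hypothetical helper names, the question is assumed to occupy question_len tokens, and an answer that does not fit entirely inside a chunk is assumed to be labeled with the [CLS] position (0), a common convention for such features.

```python
from typing import List, Tuple


def window_context(question_len: int, context_len: int,
                   max_length: int = 384, doc_stride: int = 128) -> List[Tuple[int, int]]:
    """Split context token positions [0, context_len) into overlapping windows.

    Each window must fit, together with the question and the 3 special tokens
    ([CLS], [SEP], [SEP]), inside max_length; consecutive windows share
    doc_stride context tokens.
    """
    budget = max_length - question_len - 3  # context tokens available per chunk
    if budget <= doc_stride:
        raise ValueError("question too long for this max_length/doc_stride")
    windows, start = [], 0
    while True:
        end = min(start + budget, context_len)
        windows.append((start, end))
        if end == context_len:
            return windows
        start = end - doc_stride  # step back so the next window overlaps this one


def label_window(window: Tuple[int, int], ans_start_tok: int, ans_end_tok: int,
                 question_len: int) -> Tuple[int, int]:
    """Map an answer's context-token span (inclusive) to positions inside one chunk.

    Assumption: returns (0, 0), the [CLS] position, if the answer is not fully inside.
    """
    win_start, win_end = window
    offset = question_len + 2  # [CLS] + question tokens + [SEP] precede the context
    if win_start <= ans_start_tok and ans_end_tok < win_end:
        return ans_start_tok - win_start + offset, ans_end_tok - win_start + offset
    return 0, 0


windows = window_context(question_len=20, context_len=800)
print(windows)  # [(0, 361), (233, 594), (466, 800)]
labels = [label_window(w, 300, 305, question_len=20) for w in windows]
print(labels)  # [(322, 327), (89, 94), (0, 0)]
```

With a 20-token question and an 800-token context, three overlapping chunks are produced; the answer at context tokens 300 to 305 lies fully inside the first two chunks and is pointed at [CLS] in the third.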
  • Define the DataModule Class — This is a Pytorch Lightning defined Class that contains all the code needed to prepare mini-batches of the data using the DataLoaders. At the start of training, the Trainer calls the prepare_data and setup functions first. There is a collate function here that pads the mini-batches, since BERT-like models require all the inputs in a mini-batch to be of the same length. Instead of padding the inputs to the longest length in the entire dataset, the collate function pads them only to the longest length within that mini-batch. This makes training faster and uses less memory. No other special padding routine is needed in this case.
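The per-batch padding idea can be sketched as a collate function. This is a hypothetical version working on plain Python lists: pad_collate is an illustrative name, pad_token_id defaults to 0 (the [PAD] id in the BERT uncased vocabulary; tokenizer.pad_token_id is the safer source in practice), and a real collate function would also convert the lists to torch tensors before returning them.

```python
from typing import Dict, List


def pad_collate(batch: List[dict], pad_token_id: int = 0) -> Dict[str, list]:
    """Pad every example to the longest input_ids in *this* batch only."""
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, attention_mask = [], []
    for ex in batch:
        n_pad = max_len - len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * n_pad)
    out = {"input_ids": input_ids, "attention_mask": attention_mask}
    # Answer labels are one integer per example and need no padding
    for key in ("start_positions", "end_positions"):
        if key in batch[0]:
            out[key] = [ex[key] for ex in batch]
    return out


batch = [
    {"input_ids": [101, 2040, 102], "start_positions": 1, "end_positions": 1},
    {"input_ids": [101, 102], "start_positions": 0, "end_positions": 0},
]
print(pad_collate(batch)["input_ids"])  # [[101, 2040, 102], [101, 102, 0]]
```

A function like this would be handed to a DataLoader through its collate_fn argument, so each mini-batch is padded only as far as it needs to be.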
  • Define the Model Class — The forward function of the DL Model is defined here. Assume a batch size of 16 and a sequence length of 384 tokens. The output from the last hidden layer of the Transformer model (of shape 16 x 384 x 768) is sent through a Linear layer that outputs a tensor of shape 16 x 384 x 2. This output is then split into start_logits and end_logits, each of shape 16 x 384. The argmax of the start_logits represents the predicted start token index, and the argmax of the end_logits represents the predicted end token index. During training, the logits are compared against the labeled start_position and end_position of the answer to calculate the loss. The model is essentially trained to find the start and end positions of the answer within the context.
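The shape flow of the span-prediction head can be sketched in Pytorch. The QAHead class and its qa_outputs layer are hypothetical names, and a random tensor stands in for the Transformer's last hidden layer, so the predicted indices are meaningless here; only the shapes matter.

```python
import torch
from torch import nn


class QAHead(nn.Module):
    """Hypothetical span-prediction head: hidden states -> start/end logits."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)  # one "start" and one "end" score per token

    def forward(self, last_hidden_state: torch.Tensor):
        logits = self.qa_outputs(last_hidden_state)           # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)    # each (batch, seq_len, 1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)  # each (batch, seq_len)


# Random stand-in for the Transformer's last hidden layer: 16 x 384 x 768
hidden = torch.randn(16, 384, 768)
start_logits, end_logits = QAHead()(hidden)
pred_start = start_logits.argmax(dim=-1)  # predicted start token index per example
pred_end = end_logits.argmax(dim=-1)      # predicted end token index per example
print(start_logits.shape, pred_start.shape)  # torch.Size([16, 384]) torch.Size([16])
```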
  • Define the Pytorch Lightning Module Class — This is where the training, validation and test step functions are defined. The model loss and accuracy are calculated in the step functions. The optimizers and schedulers are defined here as well. The CrossEntropyLoss is calculated separately for the Start and End positions of the answer, and the two losses are then averaged.
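A minimal sketch of the averaged loss, assuming start and end logits of shape batch x seq_len and integer label positions (qa_loss is a hypothetical helper name). With all-zero logits every position is equally likely, so each CrossEntropyLoss, and therefore their average, equals ln(seq_len), which makes a quick sanity check.

```python
import math

import torch
from torch import nn


def qa_loss(start_logits: torch.Tensor, end_logits: torch.Tensor,
            start_positions: torch.Tensor, end_positions: torch.Tensor) -> torch.Tensor:
    """Average of the start and end CrossEntropyLosses (logits: batch x seq_len)."""
    loss_fct = nn.CrossEntropyLoss()
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    return (start_loss + end_loss) / 2


# Uniform (all-zero) logits: each loss is ln(seq_len)
batch, seq_len = 2, 4
zeros = torch.zeros(batch, seq_len)
loss = qa_loss(zeros, zeros, torch.tensor([1, 2]), torch.tensor([2, 3]))
print(round(loss.item(), 4), round(math.log(seq_len), 4))  # 1.3863 1.3863
```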
  • Define the Trainer Parameters — All the required Trainer parameters and Trainer callbacks are defined here. We have defined 3 different callbacks: EarlyStopping, LearningRate Monitor and Checkpointing. Instead of using argparse to define the parameters, the latest Pytorch Lightning update allows them to be defined in a .yaml file, which can be passed as an argument when running a Python .py script from the CLI. This way the Trainer parameters are kept separate from the training code. Since we are using a Colab Notebook for demo purposes, we stick with the argparse way.
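A sketch of the argparse approach inside a notebook. The option names are hypothetical, and the defaults are taken from the values mentioned in this article (distilbert-base-uncased, 384 tokens, a doc_stride of 128, batch size 16, 3 Epochs). The detail worth noting is parsing an explicit list: inside Colab or Jupyter, sys.argv contains the kernel's own flags, so a bare parse_args() call would fail on them.

```python
import argparse


def build_args(argv=None) -> argparse.Namespace:
    """Hypothetical hyper-parameter parser; option names are illustrative."""
    parser = argparse.ArgumentParser(description="SQuAD 1.1 QA fine-tuning")
    parser.add_argument("--model_checkpoint", default="distilbert-base-uncased")
    parser.add_argument("--max_length", type=int, default=384)
    parser.add_argument("--doc_stride", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_epochs", type=int, default=3)
    # In Colab/Jupyter, sys.argv holds the kernel's own flags, so parse an
    # explicit list instead (an empty list means "use the defaults").
    return parser.parse_args([] if argv is None else argv)


args = build_args()
print(args.max_length, args.doc_stride, args.max_epochs)  # 384 128 3
override = build_args(["--max_epochs", "5"])
print(override.max_epochs)  # 5
```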
  • Train the Model — This is done using the Trainer.fit() method. A profiler can be defined in the Trainer parameters to give more information on the Training run timings.
  • Evaluate the Model Performance — This is not as straightforward as in the prior examples. The most obvious way to predict an answer for each feature is to take the index of the maximum start logit as the start position and the index of the maximum end logit as the end position. This works well in a lot of cases, but what if the prediction is impossible? The start position could be greater than the end position, or the span could point to text in the question instead of the context. In that case, we might want to look at the second-best prediction to see if it gives a possible answer and select that instead.
  • However, picking the second best answer is not as easy as picking the best one: is it the second best index in the start logits with the best index in the end logits? Or the best index in the start logits with the second best index in the end logits? And if that second best answer is not possible either, it gets even trickier for the third best answer.
  • To rank the candidate answers, each one is scored by adding its start and end logits. Instead of scoring every possible combination of start and end positions, the search is limited by a hyper-parameter called n_best_size (e.g. 20): the n_best_size highest indices in the start logits and in the end logits are taken, and all the answers they can form are gathered. After checking that each one is valid, the candidates are sorted by score and the best one is kept.
  • The post-processing code for this evaluation is quite lengthy, so I will refer you to the HuggingFace tutorial Notebook for it.
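A condensed sketch of the n_best_size search for a single feature, in plain Python. best_span is a hypothetical name; the convention that question and special tokens have an offset of None, and the extra max_answer_length filter, are assumptions of this sketch rather than details taken from the notebook. It also does not combine candidates across several chunks of the same Context, which the full post-processing has to do.

```python
from typing import List, Optional, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              offsets: List[Optional[Tuple[int, int]]], context: str,
              n_best_size: int = 20, max_answer_length: int = 30) -> dict:
    """Pick the highest-scoring valid answer span for one feature."""
    # Indices of the n_best_size largest start and end logits, best first
    top_starts = sorted(range(len(start_logits)), key=lambda i: start_logits[i],
                        reverse=True)[:n_best_size]
    top_ends = sorted(range(len(end_logits)), key=lambda i: end_logits[i],
                      reverse=True)[:n_best_size]
    candidates = []
    for s in top_starts:
        for e in top_ends:
            # None offsets mark [CLS], question and [SEP] tokens: not in the context
            if offsets[s] is None or offsets[e] is None:
                continue
            # Skip spans that end before they start or are implausibly long
            if e < s or e - s + 1 > max_answer_length:
                continue
            candidates.append({
                "score": start_logits[s] + end_logits[e],
                "text": context[offsets[s][0]:offsets[e][1]],
            })
    # Sort the valid candidates by score and keep the best one
    candidates.sort(key=lambda c: c["score"], reverse=True)
    return candidates[0] if candidates else {"score": 0.0, "text": ""}


context = "Mary appeared to Saint Bernadette in 1858."
# Positions: 0 [CLS], 1 question token, 2 [SEP], 3-10 context tokens, 11 [SEP]
offsets = [None, None, None, (0, 4), (5, 13), (14, 16), (17, 22),
           (23, 33), (34, 36), (37, 41), (41, 42), None]
start_logits = [0.0] * 12
start_logits[1] = 9.0   # the argmax lands on a question token
start_logits[6] = 8.0   # "Saint"
end_logits = [0.0] * 12
end_logits[4] = 7.5     # the argmax ends before the best valid start
end_logits[7] = 7.0     # "Bernadette"
print(best_span(start_logits, end_logits, offsets, context, n_best_size=2))
# {'score': 15.0, 'text': 'Saint Bernadette'}
```

Here the naive argmax picks a question token as the start and an end position before the best valid start, so it yields an impossible span; restricting the search to valid combinations of the top-2 indices recovers "Saint Bernadette" with a score of 8.0 + 7.0 = 15.0.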
  • The SOTA F1 score for SQuAD 1.1 Question Answering is 95.38. After 3 Epochs, the notebook reached a validation loss of 1.11, which corresponds to an F1 score of roughly 85.13 based on the Huggingface notebook results, well below the SOTA score. Reaching F1 scores near 95.0 will require training for more Epochs as well as using a bigger model like XLNet.
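For reference, the SQuAD F1 score compares the predicted and reference answers as bags of tokens after normalization. The sketch below follows the normalization used by the official SQuAD v1.1 evaluation script (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace); the function names are illustrative. The reported F1 is the average of the per-question scores, multiplied by 100.

```python
import collections
import re
import string


def normalize_answer(s: str) -> str:
    """Lowercase, strip punctuation and the articles a/an/the, collapse whitespace."""
    s = s.lower()
    exclude = set(string.punctuation)
    s = "".join(ch for ch in s if ch not in exclude)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def f1_score(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection: a shared token counts at most min(count) times
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_f1(prediction: str, ground_truths: list) -> float:
    """Score against every reference answer and keep the best match."""
    return max(f1_score(prediction, gt) for gt in ground_truths)


print(round(squad_f1("Saint Bernadette", ["Saint Bernadette Soubirous"]), 4))  # 0.8
```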
  • Run Inference on the Trained Model — Shows how to send a sample batch of text to the trained model to get predictions. This can be used when building the ML inference pipeline.
  • TensorBoard Logs Data — This opens TensorBoard within the Colab notebook so you can look at the various TensorBoard logs. Pytorch Lightning logs to TensorBoard by default; this can be changed by passing a different Logger to the Trainer.

Next we will take a look at the Summarization task training in Part 6 of this Series.

Narayana Swamy

Over 17 years of diverse global experience in Data Science, Finance and Operations. Passionate about using data to unlock Business value.