Python Guide To Google’s T5 Transformer For Text Summarizer

Transfer learning has driven much of the recent progress in deep learning. This is especially true in Natural Language Processing, where the rise of Transformers has produced many new approaches to language modelling. The common recipe is to pre-train a big model (often a state-of-the-art architecture) on a huge text corpus with a generic objective, and then fine-tune it for specific tasks.

In this blog, we will discuss Google AI’s state-of-the-art T5 (Text-to-Text Transfer Transformer), a model that treats every problem as text in, text out. It was introduced in the paper “Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer”, published in JMLR in 2020. The paper surveys modern transfer learning techniques used in Natural Language Understanding and proposes a unified framework that casts every language problem into a text-to-text format.

Take a BERT-style architecture as an example. The model is first pre-trained on two objectives, Masked Language Modeling and Next Sentence Prediction. It is then fine-tuned for each downstream task, for example predicting the class label in a classification problem or answering questions in a simple Question Answering setup. Each task therefore ends up with its own separately fine-tuned model, usually with a task-specific output layer.

In contrast, the text-to-text framework (both input and output are text sequences) uses the same hyperparameters, the same loss function and even the same model (same checkpoints and layers) for every NLP task. The input is written so that the model can recognize which task to perform, and the output is simply the text form of the expected result.

For example, consider the following inputs:

“cola sentence: The course is jumping well.” -> T5 -> “not acceptable”

“translate English to German: That is good.” -> T5 -> “Das ist gut.”

“summarize: some text in paragraph” -> T5 -> “summary”

As these examples show, we only need to prepend the task name to the input sequence (easily done with string formatting) and feed it to the T5 transformer. Quite easy, right? Because the output is also plain text, no task-specific output layer or label encoding is needed: the generated token IDs are simply decoded back into a string.
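Prepending the task name really is ordinary string formatting. The sketch below builds T5-style inputs for the three examples above; the `TASK_PREFIXES` mapping and the `to_t5_input` helper are hypothetical illustrations, not part of any library.

```python
# Hypothetical mapping from a task name to the text prefix T5 expects.
TASK_PREFIXES = {
    "cola": "cola sentence: ",
    "translate_en_de": "translate English to German: ",
    "summarize": "summarize: ",
}


def to_t5_input(task: str, text: str) -> str:
    """Prepend the task prefix so a single model can tell the tasks apart."""
    if task not in TASK_PREFIXES:
        raise ValueError(f"unknown task: {task!r}")
    return TASK_PREFIXES[task] + text.strip()


if __name__ == "__main__":
    print(to_t5_input("translate_en_de", "That is good."))
    # prints: translate English to German: That is good.
```

The resulting string goes through the same tokenizer for every task, and the model's output is decoded back to text in the same way regardless of the task.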

T5 is pre-trained on the C4 dataset (Colossal Clean Crawled Corpus), a heavily cleaned version of the standard Common Crawl corpus that keeps the text of web pages without their HTML markup. It is called clean because pages containing source code were removed, pages containing words from a list of offensive terms were dropped, duplicated text was removed and pages not in English were filtered out. The result is a dataset of about 750 GB.
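To make these cleaning steps concrete, here is a heavily simplified sketch of C4-style heuristics applied to a list of pages. It is not the actual C4 pipeline: the `BAD_WORDS` placeholder set, the `min_lines` threshold and the function names are assumptions for illustration, deduplication is done on whole pages rather than on repeated sentence spans, and the English-language filter is omitted.

```python
import re
from typing import Iterable, List, Optional

# Placeholder standing in for the real list of offensive words (assumption).
BAD_WORDS = {"badword"}
# A kept line must end like a sentence: period, !, ? or closing quote.
TERMINAL_PUNCTUATION = (".", "!", "?", '"')


def clean_page(page: str, min_lines: int = 2) -> Optional[str]:
    """Return the cleaned page text, or None if the page should be dropped."""
    lowered = page.lower()
    # Page-level filters: code-like content and placeholder text.
    if "{" in page or "lorem ipsum" in lowered:
        return None
    # Page-level filter: any word from the offensive-word list.
    if set(re.findall(r"[a-z']+", lowered)) & BAD_WORDS:
        return None
    # Line-level filters: sentence-like lines with at least 3 words.
    kept = [
        line.strip()
        for line in page.splitlines()
        if line.strip().endswith(TERMINAL_PUNCTUATION)
        and len(line.split()) >= 3
        and "javascript" not in line.lower()
    ]
    if len(kept) < min_lines:
        return None
    return "\n".join(kept)


def clean_pages(pages: Iterable[str]) -> List[str]:
    """Clean every page and drop exact duplicates of already-kept pages."""
    seen = set()
    result = []
    for page in pages:
        cleaned = clean_page(page)
        if cleaned is not None and cleaned not in seen:
            seen.add(cleaned)
            result.append(cleaned)
    return result
```

Navigation text such as "Home | About" is discarded because it does not end with terminal punctuation, which is the main way these rules strip boilerplate from web pages.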

Let’s dive into the code! 

Code Implementation of Text Summarization

Setup and Importing Dependencies

Make sure the notebook runs on a GPU before you start. Next, we install and import all the libraries and utilities we need, and set up a small plotting boilerplate so that every plot uses the same color scheme. The library versions below are pinned on purpose: later releases of transformers and pytorch-lightning renamed or removed some of the arguments used in this tutorial.

 # change the runtime and make sure it is set to GPU:
 # Runtime -> Change runtime type -> GPU
 # this restarts the runtime, so do it at the very start
 # Hugging Face transformers and PyTorch Lightning (pinned versions)
 !pip install --quiet transformers==4.5.0
 !pip install --quiet pytorch-lightning==1.2.7
 # check which GPU you have been assigned
 !nvidia-smi
 '''
 json for reading JSON files
 pandas for DataFrame manipulation
 numpy for array operations
 torch as the backend of PyTorch Lightning
 '''
 import json
 import pandas as pd
 import numpy as np
 import torch
 # free any cached GPU memory left over from earlier runs
 torch.cuda.empty_cache()
 # filesystem paths
 from pathlib import Path
 # Dataset and DataLoader base classes
 from torch.utils.data import Dataset, DataLoader
 # Lightning for the data and model classes
 import pytorch_lightning as pl
 # save the best model checkpoint during training
 from pytorch_lightning.callbacks import ModelCheckpoint
 # log metrics so we can visualize them in TensorBoard
 from pytorch_lightning.loggers import TensorBoardLogger
 # splitting the data
 from sklearn.model_selection import train_test_split
 # ANSI color formatting for terminal output
 from termcolor import colored
 # wraps long paragraphs into lines of fixed width
 import textwrap
 # optimizer, T5 model with a language-modeling head, and the fast tokenizer
 from transformers import (
     AdamW,
     T5ForConditionalGeneration,
     T5TokenizerFast as T5Tokenizer
 )
 # progress bars in the notebook
 from tqdm.auto import tqdm
 # seaborn and matplotlib for visualization
 import seaborn as sns
 import matplotlib.pyplot as plt
 # rcParams holds the default settings for all plots
 from matplotlib import rcParams
 # show graphs inline in the notebook cell
 %matplotlib inline
 # render higher-resolution images
 %config InlineBackend.figure_format='retina'
 # default plot style
 sns.set(style='whitegrid', palette='muted', font_scale=1.2)
 # default figure size (width, height) in inches
 rcParams['figure.figsize'] = 16, 10
 # seed the Python, NumPy and PyTorch random generators for reproducibility
 pl.seed_everything(42)
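The install cell pins transformers 4.5.0 and pytorch-lightning 1.2.7. A quick way to confirm what is actually installed is the standard library's `importlib.metadata` (Python 3.8+). The sketch below is a hypothetical helper, not part of the original notebook; it reports every package whose installed version differs from the pin, with `None` meaning the package is not installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions pinned in this tutorial.
PINNED = {"transformers": "4.5.0", "pytorch-lightning": "1.2.7"}


def version_mismatches(pinned: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Map each package whose installed version differs from its pin to that version."""
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in pinned.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            # the package is not installed at all
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


print(version_mismatches(PINNED) or "all pinned versions are installed")
```

An empty result means the environment matches the versions the rest of the code was written against.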
Download and Prepare the Data

Here we download the preprocessed news summary data from Google Drive, which saves preprocessing time, and load it into a pandas DataFrame. We keep only the columns we need, rename them, drop rows with missing values and split the data into two parts, a training set and a test set.

 # download the preprocessed data from Google Drive
 !gdown --id 1DXsaWG9p3oQjkKu19mx_ZQ8UjHtsoGHW
 # load the CSV into a DataFrame and take a look
 df = pd.read_csv("news_summary.csv", encoding='latin-1')
 df.head()
 # keep only the summary ('text') and full article ('ctext') columns
 df = df[['text', 'ctext']]
 df.head()
 # give the columns descriptive names
 df.columns = ['summary', 'text']
 # drop rows with missing (NaN) values
 df = df.dropna()
 df.head()
 # hold out 10% of the rows for testing (a 9:1 train/test split)
 train_df, test_df = train_test_split(df, test_size=0.1)
 # check the shape of both splits
 train_df.shape, test_df.shape
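`train_test_split(df, test_size=0.1)` holds out 10% of the rows, which is a 9:1 split. As far as we know, scikit-learn rounds the test count up when the fraction does not divide the number of rows evenly; the stdlib sketch below mimics that rule with a shuffled split. `split_sizes` and `shuffle_split` are hypothetical helpers for illustration, not the scikit-learn implementation.

```python
import math
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_sizes(n_samples: int, test_size: float) -> Tuple[int, int]:
    """Return (n_train, n_test), rounding the test count up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def shuffle_split(items: Sequence[T], test_size: float, seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of the items and cut it into train and test parts."""
    n_train, _ = split_sizes(len(items), test_size)
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:]


train, test = shuffle_split(list(range(1005)), test_size=0.1)
print(len(train), len(test))  # 904 101
```

Passing a fixed seed (in scikit-learn, the `random_state` argument) makes the split reproducible across runs.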
Class for Dataset

The dataset class wraps the DataFrame so that PyTorch can index it. It stores the data, the tokenizer and the maximum token lengths of the input text and of the summary. For each row it tokenizes both sequences, truncating and padding them to fixed lengths and adding special tokens, and returns the input IDs, attention masks and labels as tensors.

 # dataset class that extends PyTorch's Dataset
 class NewsSummaryDataset(Dataset):
     def __init__(
         self,
         # data in the form of a DataFrame
         data: pd.DataFrame,
         # the tokenizer
         tokenizer: T5Tokenizer,
         # max token length of the input text
         text_max_token_len: int = 512,
         # max token length of the (shorter) summary
         summary_max_token_len: int = 128
     ):
         self.tokenizer = tokenizer
         self.data = data
         self.text_max_token_len = text_max_token_len
         self.summary_max_token_len = summary_max_token_len
     # number of samples
     def __len__(self):
         return len(self.data)
     # return the encoded sample at the given index
     def __getitem__(self, index: int):
         # data row at the current index
         data_row = self.data.iloc[index]
         # the full article text
         text = data_row['text']
         # encode the text with the dataset's own tokenizer
         text_encoding = self.tokenizer(
             text,
             # pad every sequence to the same length
             max_length=self.text_max_token_len,
             padding='max_length',
             # cut longer sequences
             truncation=True,
             # mask marking real tokens (1) and padding (0)
             return_attention_mask=True,
             # append the end-of-sequence token
             add_special_tokens=True,
             # return PyTorch tensors
             return_tensors='pt'
         )
         # the summary is encoded the same way
         summary_encoding = self.tokenizer(
             data_row['summary'],
             max_length=self.summary_max_token_len,
             padding='max_length',
             truncation=True,
             return_attention_mask=True,
             add_special_tokens=True,
             return_tensors='pt'
         )
         # the summary token IDs become the labels;
         # padding positions are set to -100 so the loss ignores them
         labels = summary_encoding['input_ids']
         labels[labels == self.tokenizer.pad_token_id] = -100
         return dict(
             # raw article text
             text=text,
             # raw reference summary
             summary=data_row['summary'],
             # flatten from shape (1, seq_len) to (seq_len,) for easy batching
             text_input_ids=text_encoding['input_ids'].flatten(),
             text_attention_mask=text_encoding['attention_mask'].flatten(),
             labels=labels.flatten(),
             labels_attention_mask=summary_encoding['attention_mask'].flatten()
         )
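To see what the tokenizer call and the `-100` replacement produce, here is a pure-Python sketch on made-up token IDs. It assumes T5's conventions as we understand them (padding ID 0, with the end-of-sequence ID 1 appended as the only special token) and mimics `padding='max_length'` combined with `truncation=True`; `encode_fixed_length` is a hypothetical helper, not a transformers function.

```python
from typing import Dict, List

PAD_ID = 0           # T5's padding token ID (assumption)
EOS_ID = 1           # T5's end-of-sequence token ID (assumption)
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips


def encode_fixed_length(token_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Truncate, append EOS, pad to max_length and build the mask and labels."""
    # Truncate so that there is still room for the EOS token.
    ids = token_ids[: max_length - 1] + [EOS_ID]
    n_pad = max_length - len(ids)
    input_ids = ids + [PAD_ID] * n_pad
    # 1 for real tokens, 0 for padding.
    attention_mask = [1] * len(ids) + [0] * n_pad
    # Padding positions must not contribute to the loss.
    labels = [tok if tok != PAD_ID else IGNORE_INDEX for tok in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


print(encode_fixed_length([37, 1782, 5], max_length=6))
# {'input_ids': [37, 1782, 5, 1, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0], 'labels': [37, 1782, 5, 1, -100, -100]}
```

The dataset class does the same masking on tensors with `labels[labels == pad_token_id] = -100`.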
Class for DataModule

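This class builds a LightningDataModule that bundles the datasets and their DataLoaders in one place, so the Lightning Trainer can create them when it needs them. PyTorch Lightning is less widely used than plain PyTorch, but it removes much of the training-loop boilerplate and makes models quick to train and run.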

 # data module for pytorch lightning
 class NewsSummaryDataModule(pl.LightningDataModule):
     def __init__(
         self,
         # pass in train data
         train_df: pd.DataFrame,
         # pass in test data
         test_df: pd.DataFrame,
         # tokenizer
         tokenizer: T5Tokenizer,
         # batch_size
         batch_size: int = 8,
         # length of sequence
         text_max_token_len: int = 512,
         # length of output sequence
         summary_max_token_len: int = 128
     ):
         super().__init__()
         # storing the data in class objects
         self.train_df = train_df
         self.test_df = test_df
         self.batch_size = batch_size
         self.tokenizer = tokenizer
         self.text_max_token_len = text_max_token_len
         self.summary_max_token_len = summary_max_token_len
     # automatically called by the trainer  
     def setup(self, stage=None):
         self.train_dataset = NewsSummaryDataset(
             self.train_df,
             self.tokenizer,
             self.text_max_token_len,
             self.summary_max_token_len
         )
         self.test_dataset = NewsSummaryDataset(
             self.test_df,
             self.tokenizer,
             self.text_max_token_len,
             self.summary_max_token_len
         )
     # training data, shuffled every epoch
     def train_dataloader(self):
         return DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=2
         )
     # test data; evaluation does not need shuffling
     def test_dataloader(self):
         return DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=2
         )
     # validation data (this tutorial reuses the test split for validation)
     def val_dataloader(self):
         return DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=2
         )
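With `batch_size=8` and the default `drop_last=False`, a `DataLoader` yields `ceil(N / 8)` batches per epoch, and the last one may be smaller. The generator below is a simplified stdlib stand-in for that behaviour (no worker processes, no collation into tensors); `iterate_batches` is a hypothetical name used only for illustration.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(
    items: Sequence[T], batch_size: int, shuffle: bool = False, seed: int = 0
) -> Iterator[List[T]]:
    """Yield consecutive batches, shuffling the order first if requested."""
    order = list(range(len(items)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [items[i] for i in order[start:start + batch_size]]


batches = list(iterate_batches(list(range(20)), batch_size=8))
print([len(b) for b in batches])  # [8, 8, 4]
```

Shuffling only changes which samples land in which batch, not how many batches there are, which is why it helps training but is unnecessary for validation and test loaders.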
Load and Fine-Tune the Model

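We use the t5-base checkpoint, which has about 220M parameters. First we load its tokenizer: the code imports T5TokenizerFast (the Rust-backed fast tokenizer) under the alias T5Tokenizer because it encodes text faster. We then count the tokens in every article and summary of the training set and plot their distributions, which helps check the maximum lengths of 512 and 128 tokens. Finally, we set the number of epochs and the batch size and create the data module.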

 # fine-tune the base-sized T5 checkpoint
 MODEL_NAME = 't5-base'
 # instantiate the (fast) tokenizer
 tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
 # token counts of every article and summary in the training data
 text_token_counts, summary_token_counts = [], []
 for _, row in train_df.iterrows():
     # encode the article and count its tokens
     text_token_count = len(tokenizer.encode(row['text']))
     text_token_counts.append(text_token_count)
     # do the same with the summary
     summary_token_count = len(tokenizer.encode(row['summary']))
     summary_token_counts.append(summary_token_count)
 # plot histograms of the token counts side by side
 fig, (ax1, ax2) = plt.subplots(1, 2)
 sns.histplot(text_token_counts, ax=ax1)
 ax1.set_title('full text token counts')
 # the same for the summaries
 sns.histplot(summary_token_counts, ax=ax2)
 ax2.set_title('summary text token counts')
 # a large pretrained model needs only a few epochs of fine-tuning
 N_EPOCHS = 3
 BATCH_SIZE = 8
 # create the data module with the chosen batch size
 data_module = NewsSummaryDataModule(train_df, test_df, tokenizer, batch_size=BATCH_SIZE)
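The histograms help choose `text_max_token_len` and `summary_max_token_len`. A quick numeric check is the share of examples that fit within each limit. The helper below is a hypothetical addition that could be run on `text_token_counts` and `summary_token_counts` from the loop above; here it is shown on made-up numbers.

```python
from typing import Sequence


def fraction_within(counts: Sequence[int], max_len: int) -> float:
    """Share of sequences whose token count does not exceed max_len."""
    if not counts:
        raise ValueError("counts must not be empty")
    return sum(1 for c in counts if c <= max_len) / len(counts)


# Made-up token counts, for illustration only.
example_counts = [120, 340, 512, 600, 880]
print(f"{fraction_within(example_counts, 512):.0%} fit in 512 tokens")  # 60% fit in 512 tokens
```

Anything above the limit is truncated by the dataset class, so a low fraction means the model only ever sees the beginning of many articles.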
Class for the Summarization Model

The summarization model is a LightningModule that wraps T5ForConditionalGeneration. It defines the forward pass, separate training, validation and test steps that each log their loss, and the optimizer (AdamW with a learning rate of 1e-4). The loss itself is the cross-entropy that T5 computes internally when labels are passed to it.

 # Lightning module that wraps T5 for summarization
 class NewsSummaryModel(pl.LightningModule):
     def __init__(self):
         super().__init__()
         # pretrained T5 with a language-modeling head
         self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
     def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
         # when labels are given, T5 also returns the cross-entropy loss
         output = self.model(
             input_ids,
             attention_mask=attention_mask,
             labels=labels,
             decoder_attention_mask=decoder_attention_mask
         )
         return output.loss, output.logits
     # shared logic of the training, validation and test steps
     def _shared_step(self, batch, stage):
         loss, outputs = self(
             input_ids=batch['text_input_ids'],
             attention_mask=batch['text_attention_mask'],
             decoder_attention_mask=batch['labels_attention_mask'],
             labels=batch['labels']
         )
         # logs "train_loss", "val_loss" or "test_loss" to the progress bar and TensorBoard
         self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
         return loss
     # Lightning passes the batch index as the second argument
     def training_step(self, batch, batch_idx):
         return self._shared_step(batch, "train")
     def validation_step(self, batch, batch_idx):
         return self._shared_step(batch, "val")
     def test_step(self, batch, batch_idx):
         return self._shared_step(batch, "test")
     # AdamW optimizer with a fixed learning rate of 1e-4
     def configure_optimizers(self):
         return AdamW(self.parameters(), lr=0.0001)
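As we understand the Hugging Face implementation, when `labels` are passed, T5ForConditionalGeneration builds its decoder inputs by shifting the labels one step to the right (starting with the padding ID) and computes a token-level cross-entropy that skips positions labelled `-100`. The pure-Python sketch below reproduces both ideas on toy numbers so the role of the `-100` labels is visible; it is an illustration under those assumptions, not the library code, and both function names are hypothetical.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100
PAD_ID = 0  # T5 also uses the padding ID as the decoder start token (assumption)


def shift_right(labels: Sequence[int], start_id: int = PAD_ID) -> List[int]:
    """Decoder inputs: start token, then the labels without the last one."""
    shifted = [start_id] + list(labels[:-1])
    # -100 is not a real token ID, so replace it with padding.
    return [PAD_ID if tok == IGNORE_INDEX else tok for tok in shifted]


def masked_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not -100."""
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue
        # log of the softmax normalizer, computed stably via the max logit
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_norm - row[label])
    if not losses:
        raise ValueError("every position is ignored")
    return sum(losses) / len(losses)


print(shift_right([37, 5, 1, -100]))  # [0, 37, 5, 1]
# Uniform logits over 2 classes cost log(2) per counted position; the last position is ignored.
print(masked_cross_entropy([[0.0, 0.0], [0.0, 0.0], [9.0, 1.0]], [0, 1, IGNORE_INDEX]))
```

Because padded positions are excluded, short summaries are not rewarded for predicting padding, and the loss reflects only the real summary tokens.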
Fit the Model
 model = NewsSummaryModel()
 # start TensorBoard inside the notebook
 %load_ext tensorboard
 %tensorboard --logdir ./lightning_logs
 # keep only the checkpoint with the lowest validation loss
 checkpoint_callback = ModelCheckpoint(
     dirpath='checkpoints',
     filename='best-checkpoint',
     save_top_k=1,
     verbose=True,
     monitor='val_loss',
     mode='min'
 )
 logger = TensorBoardLogger("lightning_logs", name='news-summary')
 trainer = pl.Trainer(
     logger=logger,
     # pass the checkpoint callback through the callbacks list
     callbacks=[checkpoint_callback],
     max_epochs=N_EPOCHS,
     # train on a single GPU
     gpus=1,
     progress_bar_refresh_rate=30
 )
 trainer.fit(model, data_module)
 # reload the best checkpoint and freeze it for inference
 trained_model = NewsSummaryModel.load_from_checkpoint(
     trainer.checkpoint_callback.best_model_path
 )
 trained_model.freeze()
Summarize Text

The function below tokenizes an article the same way as during training, generates a summary with beam search and decodes the generated token IDs back into text. We will use it to produce sample outputs later on.

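 def summarizeText(text):
     # tokenize the article exactly as during training
     text_encoding = tokenizer(
         text,
         max_length=512,
         padding='max_length',
         truncation=True,
         return_attention_mask=True,
         add_special_tokens=True,
         return_tensors='pt'
     )
     # generate a summary with beam search
     generated_ids = trained_model.model.generate(
         input_ids=text_encoding['input_ids'],
         attention_mask=text_encoding['attention_mask'],
         # at most 150 tokens in the output
         max_length=150,
         # keep the 2 best partial summaries at each step
         num_beams=2,
         # discourage tokens that were already generated
         repetition_penalty=2.5,
         # beam scores are divided by length ** 1.0
         length_penalty=1.0,
         # stop once num_beams finished candidates exist
         early_stopping=True
     )
     # decode the token IDs back into text
     preds = [
         tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
         for gen_id in generated_ids
     ]
     return "".join(preds)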
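Two of the `generate` arguments are easy to misread. As documented for Hugging Face transformers, `repetition_penalty` rescales the score of every token that already appears in the sequence (positive scores are divided by the penalty, negative ones multiplied by it), and beam search ranks a finished hypothesis by its summed log-probability divided by its length raised to `length_penalty`. The sketch below re-implements both formulas on plain Python values for illustration only; the function names are hypothetical and this is not the library code.

```python
from typing import Dict, Iterable


def apply_repetition_penalty(
    scores: Dict[int, float], previous_tokens: Iterable[int], penalty: float
) -> Dict[int, float]:
    """Make already-generated tokens less likely (penalty > 1 discourages repeats)."""
    adjusted = dict(scores)
    for tok in set(previous_tokens):
        if tok in adjusted:
            s = adjusted[tok]
            # lower the score in both cases: shrink positives, push negatives further down
            adjusted[tok] = s * penalty if s < 0 else s / penalty
    return adjusted


def length_normalized_score(sum_log_probs: float, length: int, length_penalty: float) -> float:
    """Beam ranking score: summed log-probabilities divided by length ** length_penalty."""
    return sum_log_probs / (length ** length_penalty)


scores = {10: 2.0, 11: -1.0, 12: 0.5}
print(apply_repetition_penalty(scores, previous_tokens=[10, 11], penalty=2.5))
# {10: 0.8, 11: -2.5, 12: 0.5}
print(length_normalized_score(-6.0, length=3, length_penalty=1.0))  # -2.0
```

With `length_penalty=1.0`, as used in `summarizeText`, beams are effectively compared by their average log-probability per token, so longer summaries are not penalized simply for having more tokens.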
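Sample Output
 # first test example: article, reference summary and model summary
 sample_row = test_df.iloc[0]
 text = sample_row['text']
 model_summary = summarizeText(text)
 print(textwrap.fill(text, 100))
 print('Reference:', sample_row['summary'])
 print('Model:', model_summary)
 # second test example
 sample_row = test_df.iloc[1]
 text = sample_row['text']
 model_summary = summarizeText(text)
 print(textwrap.fill(text, 100))
 print('Reference:', sample_row['summary'])
 print('Model:', model_summary)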
Saving the Model
 import pickle
 # save the fine-tuned T5 model to disk
 with open('text_summarization_model.pkl', 'wb') as f:
     pickle.dump(trained_model.model, f)
 # load the saved model back
 with open('text_summarization_model.pkl', 'rb') as f:
     model = pickle.load(f)
 # same summarization helper as above, but using the reloaded model
 def summarizeText(text):
     text_encoding = tokenizer(
         text,
         max_length=512,
         padding='max_length',
         truncation=True,
         return_attention_mask=True,
         add_special_tokens=True,
         return_tensors='pt'
     )
     generated_ids = model.generate(
         input_ids=text_encoding['input_ids'],
         attention_mask=text_encoding['attention_mask'],
         max_length=150,
         num_beams=2,
         repetition_penalty=2.5,
         length_penalty=1.0,
         early_stopping=True
     )
     preds = [
             tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
             for gen_id in generated_ids
     ]
     return "".join(preds) 

EndNote

We have briefly introduced T5, an all-in-one text-to-text transformer for NLP tasks, and fine-tuned it for news summarization. I recommend trying other datasets (the news summary dataset used in this article is available on Kaggle) and exploring other applications of this model.

Copyright Analytics India Magazine Pvt Ltd
