
Multi-Class Text Classification in PyTorch using TorchText


Text classification is one of the important applications of Natural Language Processing. There are a variety of machine learning techniques for classifying text, but most of them require a large amount of preprocessing and consume a lot of computational resources.

PyTorch provides a powerful way to implement complex model architectures and algorithms with comparatively little preprocessing and lower consumption of computational resources, including execution time. The basic unit of PyTorch is the tensor, and the framework supports dynamic computation graphs that can change during runtime, as well as distributed training across GPUs. PyTorch also provides a powerful library named TorchText that contains scripts for preprocessing text and loaders for a few popular NLP datasets.

In this article, we will demonstrate multi-class text classification using TorchText, a powerful Natural Language Processing library in PyTorch. For this classification, we will use a model composed of an EmbeddingBag layer and a linear layer. The EmbeddingBag layer handles text entries of varying length by computing the mean value of the bag of embeddings. The model will be trained on the DBpedia dataset, whose texts belong to 14 classes. After successful training, the model will predict the class label for an input text.

DBpedia Dataset

DBpedia is a popular benchmark dataset in the field of natural language processing. It contains texts belonging to 14 classes such as Company, EducationalInstitution, Artist, and Film. It is a set of structured content extracted from the information created in the Wikipedia project. The DBpedia dataset provided by TorchText has 630,000 text instances belonging to the 14 classes, comprising 560,000 training instances and 70,000 test instances.

Implementing Text Classification using TorchText

First of all, we need to install TorchText. This implementation relies on the text_classification datasets API, so we pin version 0.4.

!pip install torchtext==0.4

After that, we will import all the required libraries.

import torch
import torchtext
from torchtext.datasets import text_classification
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
from torch.utils.data.dataset import random_split
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

In the next step, we will define the ngrams and batch size. The ngrams feature is used to capture important information about local word order. With bi-grams, each example text in the dataset becomes a list of single words plus bi-gram strings.

NGRAMS = 2
BATCH_SIZE = 16
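
To see what this n-grams expansion looks like, here is a minimal sketch using the tokenizer helpers imported above (the sentence is made up for illustration):

tokenizer = get_tokenizer("basic_english")
tokens = tokenizer("the church is in norway")
# yields the original unigrams followed by the joined bi-grams
print(list(ngrams_iterator(tokens, 2)))
# ['the', 'church', 'is', 'in', 'norway',
#  'the church', 'church is', 'is in', 'in norway']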

Now, we will read the DBpedia dataset provided by TorchText. On the first run, it will be downloaded into the ./.data directory.

if not os.path.isdir('./.data'):
    os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['DBpedia'](
    root='./.data', ngrams=NGRAMS, vocab=None)

After downloading the dataset, we will verify its length and the number of class labels.

print(len(train_dataset))    # 560000
print(len(test_dataset))     # 70000

print(len(train_dataset.get_labels()))   # 14
print(len(test_dataset.get_labels()))    # 14



We will use CUDA to speed up computation if a GPU is available.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


In the next step, we will define the model for the classification.

class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # EmbeddingBag computes the mean of all token embeddings in each entry
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        # text is a flat tensor of token ids; offsets marks where each entry starts
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


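Before moving on, it may help to see how EmbeddingBag consumes a flat tensor of token ids together with an offsets tensor marking entry boundaries. Below is a minimal sketch; the vocabulary size, embedding dimension, and token ids are made up purely for illustration:

bag = nn.EmbeddingBag(10, 3, mode='mean')           # toy vocab of 10, 3-dim embeddings
flat_text = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # two entries concatenated together
offsets = torch.tensor([0, 4])                      # entries are flat_text[0:4] and flat_text[4:]
print(bag(flat_text, offsets).shape)                # torch.Size([2, 3]): one mean vector per entry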



Now, we will initialize the hyperparameters, instantiate the model, and define the function that generates training batches.

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
print(model)

def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    # offsets records where each individual entry starts in the flat text tensor
    offsets = [0] + [len(entry) for entry in text]
    # cumsum converts the per-entry lengths into start positions
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label
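
To make the offsets logic concrete, here is what generate_batch returns for a hypothetical mini-batch of two (label, token-id tensor) entries; the ids and labels are made up for illustration:

batch = [(0, torch.tensor([4, 7, 2])), (3, torch.tensor([5, 1]))]
text, offsets, label = generate_batch(batch)
print(text)     # tensor([4, 7, 2, 5, 1]): token ids concatenated into one flat tensor
print(offsets)  # tensor([0, 3]): start position of each entry within the flat tensor
print(label)    # tensor([0, 3])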

In the next step, we will define the functions to train and test the model.

def train_func(sub_train_):

    # Train the model
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            batch_loss = criterion(output, cls)
            loss += batch_loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

We will train the model for 5 epochs, holding out 5% of the training set for validation.

N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

In the next step, we will evaluate our model on the test dataset and check its accuracy.

print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')


Now, we will test our model on individual text strings and predict the class label for the given text.

DBpedia_label = {0: 'Company',
                1: 'EducationalInstitution',
                2: 'Artist',
                3: 'Athlete',
                4: 'OfficeHolder',
                5: 'MeanOfTransportation',
                6: 'Building',
                7: 'NaturalPlace',
                8: 'Village',
                9: 'Animal',
                10: 'Plant',
                11: 'Album',
                12: 'Film',
                13: 'WrittenWork'}

def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        text = torch.tensor([vocab[token]
                            for token in ngrams_iterator(tokenizer(text), ngrams)])
        output = model(text, torch.tensor([0]))
        # labels in the dataset are 0-based, so argmax maps directly to DBpedia_label keys
        return output.argmax(1).item()

vocab = train_dataset.get_vocab()
model = model.to("cpu")

Now, we will take some random texts from the test data and check the predicted class label.

First prediction:

ex_text_str = "Brekke Church (Norwegian: Brekke kyrkje) is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjørgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site."

print("This is a %s news" %DBpedia_label[predict(ex_text_str, model, vocab, 2)])

Second prediction:

ex_text_str2 = "Cerithiella superba is a species of very small sea snail, a marine gastropod mollusk in the family Newtoniellidae. This species is known from European waters. It was described by Thiele, 1912."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str2, model, vocab, 2)])

Third prediction:

ex_text_str3 = "  Nithari is a village in the western part of the state of Uttar Pradesh India bordering on New Delhi. Nithari forms part of the New Okhla Industrial Development Authority's planned industrial city Noida falling in Sector 31. Nithari made international news headlines in December 2006 when the skeletons of a number of apparently murdered women and children were unearthed in the village."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str3, model, vocab, 2)])

So, in this way, we have implemented multi-class text classification using TorchText. It is a simple and easy way of performing text classification with very little preprocessing using this PyTorch library. It took less than 5 minutes to train the model on 560,000 training instances. You can re-implement this by changing the ngrams from 2 to 3 and see the results. The same implementation can be done on the other datasets provided by TorchText.

References:

  1. 'Text Classification with TorchText', PyTorch tutorials
  2. Allen Nie, 'A Tutorial on TorchText'
