Multi-Class Text Classification in PyTorch using TorchText


Text classification is one of the important applications of Natural Language Processing. There are a variety of machine learning approaches to classifying text, but most of them require a large amount of preprocessing and consume a lot of computational resources.

PyTorch makes it possible to implement complex model architectures and algorithms with comparatively little preprocessing and fewer computational resources, including execution time. The basic unit of PyTorch is the tensor. PyTorch builds its computation graph dynamically, so the architecture can change at run time, and it supports distributed training across multiple GPUs. The PyTorch project also provides TorchText, a library that contains utilities for preprocessing text as well as loaders for a few popular NLP datasets.


In this article, we will demonstrate multi-class text classification using TorchText, a powerful Natural Language Processing library for PyTorch. The classifier is composed of an EmbeddingBag layer followed by a linear layer. The EmbeddingBag layer handles text entries of varying length by computing the mean of the embeddings of all tokens in each entry (a "bag" of embeddings). The model will be trained on the DBpedia dataset, whose texts belong to 14 classes. After training, the model will predict the class label for an input text.
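To make the mean-of-embeddings idea concrete, here is a minimal pure-Python sketch of the computation nn.EmbeddingBag performs in its default mean mode. The texts of a batch are concatenated into one flat list of token indices, and offsets marks where each text starts. The helper name embedding_bag_mean and the toy numbers are illustrative assumptions; the real layer works on tensors and learns its weight matrix during training.

```python
def embedding_bag_mean(weight, text, offsets):
    """Average the embedding rows of each bag (hypothetical helper).

    weight:  list of embedding vectors, one per vocabulary index
    text:    flat list of token indices for all texts in the batch
    offsets: start position of each text (bag) inside `text`
    """
    dim = len(weight[0])
    bags = []
    # Each bag ends where the next one starts; the last bag runs to the end.
    ends = offsets[1:] + [len(text)]
    for start, end in zip(offsets, ends):
        rows = [weight[idx] for idx in text[start:end]]
        if not rows:
            # An empty bag yields a zero vector.
            bags.append([0.0] * dim)
            continue
        bags.append([sum(row[d] for row in rows) / len(rows) for d in range(dim)])
    return bags


# Toy vocabulary of 4 tokens with 2-dimensional embeddings.
weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
# Two texts of different lengths, [0, 1, 2] and [3], concatenated.
text = [0, 1, 2, 3]
offsets = [0, 3]
print(embedding_bag_mean(weight, text, offsets))  # [[1.0, 1.0], [4.0, 0.0]]
```

Whatever the length of a text, its bag collapses to a single embed_dim-sized vector, which is what lets the following linear layer accept texts of any length.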

DBpedia Dataset

DBpedia is a popular benchmark dataset in natural language processing. Its texts belong to 14 classes, such as Company, EducationalInstitution, Artist and Film. DBpedia itself is a project that extracts structured content from the information created in the Wikipedia project. The DBpedia dataset provided by TorchText has 630,000 text instances in 14 classes: 560,000 training instances and 70,000 test instances.

Implementing Text Classification using TorchText

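First of all, we need to install TorchText. The code below relies on the text_classification module of torchtext.datasets, which is not available in recent TorchText releases, so we pin version 0.4.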

!pip install torchtext==0.4

After that, we will import all the required libraries.

import os
import re
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
from torch.utils.data import DataLoader, random_split
from torchtext.data.utils import get_tokenizer, ngrams_iterator
from torchtext.datasets import text_classification

In the next step, we will define the ngrams and batch size. N-grams capture important information about local word order. With bigrams, each example text in the dataset becomes a list of single words plus bigram strings.
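The code defining these values is missing from the article. A minimal sketch follows: NGRAMS = 2 matches the value passed to predict() later in the article, while BATCH_SIZE = 16 is an assumed value. The hypothetical helper bag_of_ngrams is a pure-Python function intended to mirror the output order of torchtext's ngrams_iterator (all unigrams first, then the space-joined bigrams); it is not the library code.

```python
# Hyperparameters for the data pipeline.
NGRAMS = 2       # bigrams, matching the value passed to predict() later
BATCH_SIZE = 16  # assumed value; the article does not state it


def bag_of_ngrams(tokens, n):
    """Yield every unigram, then every k-gram (2 <= k <= n) joined by spaces."""
    for token in tokens:
        yield token
    for k in range(2, n + 1):
        for i in range(len(tokens) - k + 1):
            yield " ".join(tokens[i:i + k])


tokens = "the white wooden church".split()
print(list(bag_of_ngrams(tokens, NGRAMS)))
# ['the', 'white', 'wooden', 'church', 'the white', 'white wooden', 'wooden church']
```

Each bigram string gets its own vocabulary entry, so "wooden church" can carry a different signal than "wooden" and "church" separately, which is how the bag-of-embeddings model recovers some word order.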


Now, we will read the DBpedia dataset that is provided by the TorchText.

if not os.path.isdir('./.data'):
    os.mkdir('./.data')
train_dataset, test_dataset = text_classification.DATASETS['DBpedia'](
    root='./.data', ngrams=NGRAMS, vocab=None)

After downloading the dataset, we will verify the length and the number of labels of the downloaded dataset.
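The code for this check is not shown in the article. A minimal sketch, assuming each dataset object supports len() and exposes get_labels() (the method the article itself calls on train_dataset later), could look like this; check_dataset is a hypothetical helper name.

```python
def check_dataset(dataset, name):
    """Print and return the size and number of labels of a dataset.

    Hypothetical helper: assumes `dataset` supports len() and has a
    get_labels() method returning the set of class labels.
    """
    n_instances = len(dataset)
    n_labels = len(dataset.get_labels())
    print(f"{name}: {n_instances} instances, {n_labels} labels")
    return n_instances, n_labels


# Usage with the datasets loaded above; based on the figures quoted for
# DBpedia, expect 560000 and 70000 instances with 14 labels each:
# check_dataset(train_dataset, "train")
# check_dataset(test_dataset, "test")
```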



We will run on a CUDA GPU, if one is available, to speed up training and inference; otherwise the code falls back to the CPU.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In the next step, we will define the model for the classification.

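class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # EmbeddingBag averages the token embeddings of each text (mode='mean' by default)
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        # Uniform initialisation in [-0.5, 0.5]; the linear bias starts at zero
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        # text: flat tensor of token ids; offsets: start index of each text
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)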



Now, we will initialize the hyperparameters and define the function to generate the training batch. 

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32  # embedding size; assumed value, not given in the original
NUM_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

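def generate_batch(batch):
    # Each entry is a (label, token-id tensor) pair
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    # Start index of each text in the concatenated tensor (cumulative lengths)
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    # Concatenate all texts into one flat tensor for EmbeddingBag
    text = torch.cat(text)
    return text, offsets, label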
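The offset arithmetic in generate_batch is easy to get wrong, so here is a pure-Python sketch of the same computation on a toy batch of three texts. The helper name bag_offsets and the token ids are illustrative assumptions; itertools.accumulate plays the role of torch's cumsum.

```python
from itertools import accumulate


def bag_offsets(texts):
    """Start index of each text once the batch is concatenated.

    Same steps as generate_batch: prepend 0, drop the last length,
    then take the cumulative sum.
    """
    lengths = [len(t) for t in texts]
    return list(accumulate([0] + lengths[:-1]))


batch_texts = [[5, 8, 2], [7, 1], [3, 3, 9, 4]]  # token ids of three texts
flat = [tok for t in batch_texts for tok in t]   # what torch.cat(text) produces
offsets = bag_offsets(batch_texts)
print(flat)     # [5, 8, 2, 7, 1, 3, 3, 9, 4]
print(offsets)  # [0, 3, 5]
```

Slicing flat between consecutive offsets recovers each original text, which is exactly how EmbeddingBag splits the flat tensor back into bags.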

In the next step, we will define the function to train and test the model.

def train_func(sub_train_):

    # Train the model
    model.train()
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate once per epoch
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    model.eval()
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            # Add the batch loss to the running total (do not overwrite it)
            loss += criterion(output, cls).item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

We will train the model for 5 epochs, holding out 5% of the training set for validation.

N_EPOCHS = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# Multiply the learning rate by 0.9 after every epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# 95% of the training data for training, 5% for validation
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(N_EPOCHS):

    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)

    secs = int(time.time() - start_time)
    mins, secs = divmod(secs, 60)

    print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
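The training setup above implies a few concrete numbers that are worth checking by hand. This plain-arithmetic sketch, which needs no PyTorch, uses the 560,000 training instances quoted for DBpedia to compute the 95/5 split, and the learning rate used in each epoch given SGD at lr=4.0 and StepLR with step size 1 and gamma=0.9. It assumes, as in train_func, that scheduler.step() is called once at the end of every epoch.

```python
# Numbers implied by the training setup (pure arithmetic).
N_TRAIN = 560_000     # DBpedia training instances
LR, GAMMA = 4.0, 0.9  # SGD learning rate and StepLR decay factor
N_EPOCHS = 5

train_len = int(N_TRAIN * 0.95)
valid_len = N_TRAIN - train_len
print(train_len, valid_len)  # 532000 28000

# With step_size=1, the learning rate used in epoch k (1-based) is LR * GAMMA**(k - 1).
lrs = [LR * GAMMA ** epoch for epoch in range(N_EPOCHS)]
for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch}: lr = {lr:.4f}")
# epoch 1: lr = 4.0000 ... epoch 5: lr = 2.6244
```

The large starting learning rate is shrunk by about a third over five epochs, so later epochs make smaller, more stable updates.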

In the next step, we will evaluate our model on the test dataset and check its accuracy.


print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')

Now, we will test our model on individual text strings and predict the class label for each given text.

DBpedia_label = {0: 'Company',
                1: 'EducationalInstitution',
                2: 'Artist',
                3: 'Athlete',
                4: 'OfficeHolder',
                5: 'MeanOfTransportation',
                6: 'Building',
                7: 'NaturalPlace',
                8: 'Village',
                9: 'Animal',
                10: 'Plant',
                11: 'Album',
                12: 'Film',
                13: 'WrittenWork'}

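def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    model.eval()
    with torch.no_grad():
        # Tokenize, add n-grams, and map each token to its vocabulary index
        text = torch.tensor([vocab[token]
                            for token in ngrams_iterator(tokenizer(text), ngrams)])
        # A single text forms one bag starting at offset 0
        output = model(text, torch.tensor([0]))
        # Class indices are 0-based, matching the keys of DBpedia_label
        return output.argmax(1).item()

vocab = train_dataset.get_vocab()
model = model.to("cpu")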
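The model has 14 outputs and is trained with CrossEntropyLoss, so its class indices run from 0 to 13, the same range as the keys of DBpedia_label. This toy sketch, using made-up scores in place of real model output, shows the argmax-to-label lookup and why no +1 shift is needed.

```python
DBPEDIA_CLASSES = ['Company', 'EducationalInstitution', 'Artist', 'Athlete',
                   'OfficeHolder', 'MeanOfTransportation', 'Building',
                   'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album',
                   'Film', 'WrittenWork']
label_map = dict(enumerate(DBPEDIA_CLASSES))  # same 0-based mapping as DBpedia_label


def argmax(scores):
    """Index of the largest score (plays the role of output.argmax(1))."""
    return max(range(len(scores)), key=scores.__getitem__)


scores = [0.1] * 14
scores[6] = 3.2  # made-up scores: pretend the model favours index 6
idx = argmax(scores)
print(idx, label_map[idx])  # 6 Building
```

Adding 1 to the index would report 'NaturalPlace' for this text, and an index of 13 would become 14, which is not a key of the mapping and would raise a KeyError.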

Now, we will take some random texts from the test data and check the predicted class label.

First Prediction:

ex_text_str = "Brekke Church (Norwegian: Brekke kyrkje) is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjørgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site."

print("This text belongs to %s class" % DBpedia_label[predict(ex_text_str, model, vocab, 2)])

Second Prediction:

ex_text_str2 = "Cerithiella superba is a species of very small sea snail, a marine gastropod mollusk in the family Newtoniellidae. This species is known from European waters. It was described by Thiele, 1912."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str2, model, vocab, 2)])

Third Prediction:

ex_text_str3 = "  Nithari is a village in the western part of the state of Uttar Pradesh India bordering on New Delhi. Nithari forms part of the New Okhla Industrial Development Authority's planned industrial city Noida falling in Sector 31. Nithari made international news headlines in December 2006 when the skeletons of a number of apparently murdered women and children were unearthed in the village."

print("This text belongs to %s class" %DBpedia_label[predict(ex_text_str3, model, vocab, 2)])

So, in this way, we have implemented multi-class text classification using TorchText. This PyTorch library makes text classification simple, with very little preprocessing. It took less than 5 minutes to train the model on 560,000 training instances. You can re-implement this with the ngrams changed from 2 to 3 and compare the results. The same implementation can be applied to the other datasets provided by TorchText.



Copyright Analytics India Magazine Pvt Ltd
