A Beginner’s Guide to Text Classification using BERT Features


Recent advances in data science have made many of a practitioner's tasks much simpler. BERT is one such advance: it handles many complex language tasks with very little effort. For complex text-based modelling, BERT models are popular because they are easy to use and often outperform earlier approaches. In this article, we do a hands-on implementation of text classification that uses BERT's text preprocessing and word embedding features. Along the way, we will see how simple it is to use BERT to quickly build and use machine learning models. The major points covered in this article are listed in the table of contents below.

Table of Contents

  1. What Is BERT?
  2. What Is Word Embedding? 
  3. How Does BERT Work?
  4. Implementation Of BERT Model

Let’s begin the discussion by understanding what BERT is.



What is BERT?

BERT, or Bidirectional Encoder Representations from Transformers, is a transformer-based machine learning technique for NLP. It pre-trains deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context. It was pre-trained on English Wikipedia (2,500M words) and BooksCorpus (800M words).

The original English-language BERT was released in two sizes: BERT-BASE and BERT-LARGE.

BERT-BASE has 12 encoder layers, a hidden size of 768, and 12 attention heads (about 110M parameters). BERT-LARGE has 24 encoder layers, a hidden size of 1024, and 16 attention heads (about 340M parameters).
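As a worked check of these sizes, the sketch below counts the weights of a standard BERT encoder from its configuration and derives the size of each attention head. The extra settings it needs (an uncased vocabulary of 30,522 tokens, 512 positions, 2 segment types, and a feed-forward size of 4 × the hidden size) are assumptions taken from the published BERT configurations rather than from this article, and `BertConfig` and `count_parameters` are hypothetical helpers written for illustration.

```python
from dataclasses import dataclass


@dataclass
class BertConfig:
    num_layers: int
    hidden_size: int
    num_heads: int
    vocab_size: int = 30522     # assumed: uncased BERT vocabulary
    max_positions: int = 512    # assumed: maximum sequence length
    type_vocab_size: int = 2    # assumed: two segment (sentence A/B) types

    @property
    def intermediate_size(self) -> int:
        return 4 * self.hidden_size  # feed-forward width: 3072 for Base, 4096 for Large

    @property
    def head_dim(self) -> int:
        return self.hidden_size // self.num_heads


def dense(n_in: int, n_out: int) -> int:
    """Parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out


def count_parameters(cfg: BertConfig) -> int:
    h, i = cfg.hidden_size, cfg.intermediate_size
    # Token, position and segment embeddings, plus one LayerNorm (scale and shift)
    embeddings = (cfg.vocab_size + cfg.max_positions + cfg.type_vocab_size) * h + 2 * h
    per_layer = (
        4 * dense(h, h)               # query, key, value and attention-output projections
        + dense(h, i) + dense(i, h)   # feed-forward network
        + 2 * (2 * h)                 # two LayerNorms per encoder layer
    )
    pooler = dense(h, h)              # dense layer applied to the [CLS] output
    return embeddings + cfg.num_layers * per_layer + pooler


base = BertConfig(num_layers=12, hidden_size=768, num_heads=12)
large = BertConfig(num_layers=24, hidden_size=1024, num_heads=16)
for name, cfg in [("BERT-BASE", base), ("BERT-LARGE", large)]:
    print(name, "head size:", cfg.head_dim, "parameters:", f"{count_parameters(cfg):,}")
```

Both models end up with 64 dimensions per attention head, and the totals (109,482,240 and 335,141,888) are close to the 110M and 340M usually quoted.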

What is Word Embedding? 

Word embedding is the conversion of the words in a document into numeric vectors, such that words with similar meanings are assigned vectors that lie close together in the vector space.


In the above image, we can see that Boston and Seattle, both of which are cities, lie close to each other in the vector space.
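Closeness in vector space is commonly measured with cosine similarity. The sketch below computes it for three made-up 4-dimensional vectors; the numbers are hypothetical and chosen only to illustrate the idea (BERT-Base vectors have 768 dimensions and are learned, not written by hand).

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two vectors: 1 means same direction, 0 means unrelated."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Hypothetical toy embeddings, invented for illustration only
embeddings = {
    "boston": np.array([0.8, 0.6, 0.1, 0.0]),
    "seattle": np.array([0.7, 0.7, 0.2, 0.1]),
    "banana": np.array([0.0, 0.1, 0.9, 0.8]),
}

sim_cities = cosine_similarity(embeddings["boston"], embeddings["seattle"])
sim_unrelated = cosine_similarity(embeddings["boston"], embeddings["banana"])
print(f"boston vs seattle: {sim_cities:.3f}")
print(f"boston vs banana:  {sim_unrelated:.3f}")
```

The two city vectors score close to 1, while the unrelated word scores much lower.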

How Does BERT Work?

BERT adds a special token, [CLS] (short for "classification"), at the start of every input sequence. The sequence of tokens then flows up the stack of encoders: each encoder layer applies self-attention, passes the result through a feed-forward network, and hands it on to the next encoder. Every position outputs a vector of size hidden_size, which is 768 for the Base model. For sentence classification tasks, we use only the output at the first position (the [CLS] token) and feed that vector to the classifier.
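The flow described above can be sketched in NumPy. This is a simplified illustration, not BERT's actual implementation: it uses random weights, a single attention head instead of 12, ReLU instead of GELU, no biases or learned LayerNorm parameters, and only 2 stacked layers instead of 12. What it does show faithfully is the shape bookkeeping: every position keeps a 768-dimensional vector through the stack, and the classifier input is simply the vector at position 0.

```python
import numpy as np

HIDDEN = 768  # hidden_size of BERT-Base


def layer_norm(x, eps=1e-12):
    """Normalize each position's vector to zero mean and unit variance."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def encoder_layer(x, p):
    """Simplified encoder layer: single-head self-attention, then a feed-forward network."""
    q, k, v = x @ p["wq"], x @ p["wk"], x @ p["wv"]
    scores = q @ k.T / np.sqrt(x.shape[-1])        # (seq_len, seq_len): every token attends to every token
    x = layer_norm(x + softmax(scores) @ v)        # residual connection + normalization
    ffn = np.maximum(0.0, x @ p["w1"]) @ p["w2"]   # ReLU here; BERT itself uses GELU
    return layer_norm(x + ffn)


def random_layer_params(rng, scale=0.02):
    shapes = {"wq": (HIDDEN, HIDDEN), "wk": (HIDDEN, HIDDEN), "wv": (HIDDEN, HIDDEN),
              "w1": (HIDDEN, 4 * HIDDEN), "w2": (4 * HIDDEN, HIDDEN)}
    return {name: rng.normal(0.0, scale, shape) for name, shape in shapes.items()}


rng = np.random.default_rng(0)
# Hypothetical input: 5 token embeddings, e.g. for "[CLS] win a prize [SEP]"
x = rng.normal(size=(5, HIDDEN))
for _ in range(2):  # BERT-Base stacks 12 such layers; 2 keep the demo fast
    x = encoder_layer(x, random_layer_params(rng))

cls_vector = x[0]  # output at the first position ([CLS]): the classifier's input
print(x.shape, cls_vector.shape)  # (5, 768) (768,)
```

In the TensorFlow Hub encoder used later in this article, outputs['sequence_output'] holds the vectors for all positions, while outputs['pooled_output'] is the [CLS] vector passed through one more dense layer with a tanh activation.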

Text Classification using BERT


Here, we will do a hands-on implementation that uses the text preprocessing and word embedding features of BERT to build a text classification model. This model will predict whether a given message is spam or ham (a legitimate, non-spam message).

The dataset used in this implementation is the SMS Spam Collection dataset, an open-source dataset from Kaggle. To download the dataset, you can click this link.

Importing dependencies 

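import pandas as pd
import numpy as np
import matplotlib.pyplot as plt  # plotting (confusion matrix)
import seaborn as sns
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import tensorflow_hub as hub
import tensorflow_text as text  # registers the TF ops needed by the BERT preprocessing model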

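If tensorflow_hub and tensorflow_text are not installed, install them with the commands below (the ! prefix runs a shell command from a Jupyter or Colab notebook). The tensorflow_text version needs to match your installed TensorFlow version.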

!pip install tensorflow_hub
!pip install tensorflow_text

Reading the file
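A minimal sketch of this step, assuming the Kaggle download is saved as spam.csv (a hypothetical path) and is Latin-1 encoded, an assumption we make because this file is commonly read that way rather than as UTF-8. The helper `load_sms_data` is hypothetical; the column names v1 (label) and v2 (message text) are the ones used throughout this article.

```python
import io

import pandas as pd


def load_sms_data(source):
    """Read the SMS spam CSV: v1 holds the label (ham/spam), v2 the message text."""
    # encoding='latin-1' is an assumption about how the file is encoded
    return pd.read_csv(source, encoding="latin-1")


# With the downloaded file (hypothetical path):
# df = load_sms_data("spam.csv")

# Tiny in-memory stand-in with the same column layout, for illustration only
sample = (
    "v1,v2,Unnamed: 2,Unnamed: 3,Unnamed: 4\n"
    "ham,See you at five,,,\n"
    "spam,You won a free prize!,,,\n"
).encode("latin-1")
df = load_sms_data(io.BytesIO(sample))
print(df.head())
```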


To check the shape (dimensions) of the dataset and count the null values in each column, we can use the following code:
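A minimal sketch of these checks using pandas' `shape` attribute and `isnull().sum()`. The small DataFrame below is a hypothetical stand-in with the same columns as the dataset.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the DataFrame loaded from the CSV (same column layout)
df = pd.DataFrame({
    "v1": ["ham", "spam", "ham"],
    "v2": ["See you at five", "Claim your free prize now", "Running late, sorry"],
    "Unnamed: 2": [np.nan, np.nan, np.nan],
    "Unnamed: 3": [np.nan, np.nan, np.nan],
    "Unnamed: 4": [np.nan, "extra text", np.nan],
})

print(df.shape)           # (number of rows, number of columns) -> (3, 5)
print(df.isnull().sum())  # count of missing values in each column
```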

As the output above shows, the columns ‘Unnamed: 2’, ‘Unnamed: 3’, and ‘Unnamed: 4’ are almost entirely empty, so we drop them.

df.drop(['Unnamed: 2', 'Unnamed: 3', 'Unnamed: 4'], axis=1, inplace=True)

To check whether the data is balanced, we use the lines of code below. The first line gives the number of spam and ham messages in column v1, and the second gives the percentage of spam and ham in that column.
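A minimal sketch of these two lines with pandas' `value_counts`, where `normalize=True` returns proportions instead of counts. The labels below are a hypothetical stand-in for column v1.

```python
import pandas as pd

# Hypothetical labels standing in for the v1 column
df = pd.DataFrame({"v1": ["ham"] * 6 + ["spam"] * 4})

counts = df["v1"].value_counts()                            # number of ham and spam messages
percentages = df["v1"].value_counts(normalize=True) * 100   # share of each class, in percent
print(counts)
print(percentages)
```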

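Here we can see that about 87% of the messages are ham and the rest are spam. To balance the data, we downsample the majority class: first we split the data into separate spam and ham datasets, then we randomly sample as many ham messages as there are spam messages and combine the two into df_balanced.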
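A minimal sketch of this downsampling with pandas' `sample` and `concat`. The helper `downsample_majority` and its `seed` parameter are hypothetical (the fixed `random_state` just makes the result reproducible), and the small DataFrame is a stand-in for the real data.

```python
import pandas as pd


def downsample_majority(df, label_col="v1", majority="ham", minority="spam", seed=42):
    """Randomly drop majority-class rows until both classes have the same count."""
    df_minority = df[df[label_col] == minority]
    df_majority = df[df[label_col] == majority]
    # Sample, without replacement, as many majority rows as there are minority rows
    df_majority_down = df_majority.sample(n=len(df_minority), random_state=seed)
    return pd.concat([df_majority_down, df_minority])


# Hypothetical imbalanced stand-in: 8 ham, 2 spam
df = pd.DataFrame({"v1": ["ham"] * 8 + ["spam"] * 2,
                   "v2": [f"message {i}" for i in range(10)]})
df_balanced = downsample_majority(df)
print(df_balanced["v1"].value_counts())
```

Downsampling discards ham messages, so the balanced set is smaller than the original; the trade-off is that the classifier no longer gains accuracy just by predicting ham.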

Next, we encode the labels numerically, mapping spam to 1 and ham to 0, as shown in the code below.

df_balanced['spam'] = df_balanced['v1'].apply(lambda x: 1 if x == 'spam' else 0)

Now we split the dataset into training and test sets using scikit-learn.

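# stratify keeps the spam/ham ratio the same in both splits; by default 25% of rows go to the test set
x_train, x_test, y_train, y_test = train_test_split(df_balanced['v2'], df_balanced['spam'], stratify=df_balanced['spam'])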

With the data ready for training and testing, we next download the pre-trained BERT model from TensorFlow Hub. To download or copy the BERT model, you can click this link.


The code below downloads the two BERT layers: a preprocessing layer that converts raw text into the inputs BERT expects, and the BERT encoder itself.

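# Preprocessing model: converts raw strings into input_word_ids, input_mask and input_type_ids
bert_preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
# Encoder: uncased BERT-Base with 12 layers (L-12), hidden size 768 (H-768), 12 attention heads (A-12)
bert_encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")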
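The preprocessing layer turns each raw string into a dictionary of fixed-length integer sequences. The plain-Python sketch below imitates that structure for one sentence. It is illustrative only, under stated assumptions: a whitespace tokenizer stands in for BERT's WordPiece tokenizer, the word IDs come from a hypothetical toy vocabulary (only [PAD]=0, [UNK]=100, [CLS]=101 and [SEP]=102 match the bert-base-uncased vocabulary), and `toy_preprocess` is a hypothetical helper. The real layer pads to 128 tokens by default; a length of 8 keeps the printout short.

```python
# Hypothetical toy vocabulary; the real preprocessor uses a WordPiece vocabulary of about 30k entries.
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102, "free": 1000, "prize": 1001}


def toy_preprocess(text, vocab, seq_length=128):
    """Mimic the structure of the BERT preprocessing output for a single sentence."""
    words = text.lower().split()[: seq_length - 2]        # uncased; leave room for [CLS] and [SEP]
    tokens = ["[CLS]"] + words + ["[SEP]"]
    ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]  # unknown words map to [UNK]
    padding = seq_length - len(ids)
    return {
        "input_word_ids": ids + [vocab["[PAD]"]] * padding,  # token IDs, padded to seq_length
        "input_mask": [1] * len(ids) + [0] * padding,        # 1 = real token, 0 = padding
        "input_type_ids": [0] * seq_length,                  # all 0: a single-segment input
    }


features = toy_preprocess("Free prize now", TOY_VOCAB, seq_length=8)
print(features["input_word_ids"])  # [101, 1000, 1001, 100, 102, 0, 0, 0]
print(features["input_mask"])      # [1, 1, 1, 1, 1, 0, 0, 0]
```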

After getting the BERT layers ready, we can build the BERT model using the below code.

# BERT layers
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)

# Neural network layers: pooled_output is a 768-dim sentence representation derived from [CLS]
l = tf.keras.layers.Dropout(0.1, name="dropout")(outputs['pooled_output'])
l = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l)  # probability of spam

# Final model
model = tf.keras.Model(inputs=[text_input], outputs=[l])
model.summary()


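The model summary shows only 769 trainable parameters (the 768 weights and 1 bias of the final dense layer), while the roughly 109.5 million parameters of the BERT encoder are non-trainable, because hub.KerasLayer is frozen by default. Next, we train this model on the training data defined above.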
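Because only the final layer is trained, fitting this model amounts to a logistic regression on the 768-dimensional pooled_output vectors, assuming the model is compiled with binary cross-entropy loss (the compile call is not shown here). The NumPy sketch below illustrates that idea rather than reproducing the Keras training: random vectors stand in for pooled_output, plain gradient descent replaces the optimizer, dropout is omitted, and `train_sigmoid_head` is a hypothetical helper.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-np.clip(z, -30.0, 30.0)))  # clipped to avoid overflow


def train_sigmoid_head(features, labels, epochs=200, lr=0.1):
    """Fit a Dense(1, sigmoid) layer on fixed features by gradient descent on binary cross-entropy."""
    n, d = features.shape
    w, b = np.zeros(d), 0.0              # d weights + 1 bias: the only trainable parameters
    for _ in range(epochs):
        p = sigmoid(features @ w + b)
        error = p - labels               # gradient of binary cross-entropy w.r.t. the logit
        w -= lr * features.T @ error / n
        b -= lr * error.mean()
    return w, b


rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200).astype(float)
# Hypothetical stand-in for frozen pooled_output vectors: 768-dim, tanh-squashed, spam rows shifted
features = np.tanh(rng.normal(size=(200, 768)) + 0.3 * labels[:, None])

w, b = train_sigmoid_head(features, labels)
predictions = (sigmoid(features @ w + b) > 0.5).astype(float)
accuracy = (predictions == labels).mean()
print("trainable parameters:", w.size + 1)  # 769
print("training accuracy:", accuracy)
```

Freezing the encoder keeps training cheap: 769 numbers are learned instead of about 110 million. Passing trainable=True to hub.KerasLayer would fine-tune the whole encoder instead, at a much higher computational cost.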

Here we train for only 2 epochs for demonstration purposes; increasing the number of epochs can improve accuracy. During training, the model reached an accuracy of about 81% and a precision of about 80%.

Now let’s make the predictions on the test data.

To check the performance of the classification model, we compute its confusion matrix on the test data with the following code.
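A minimal sketch of this step, assuming scikit-learn's `confusion_matrix` is used. Because the output layer is a sigmoid, model.predict returns an (n, 1) array of spam probabilities, so it is flattened and thresholded at 0.5 before being compared with the true labels. The arrays below are hypothetical stand-ins for y_test and the model's output.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical stand-ins: true labels and model.predict(x_test)-style probabilities of shape (n, 1)
y_test = np.array([0, 0, 1, 1, 1])
y_prob = np.array([[0.1], [0.7], [0.8], [0.4], [0.95]])

y_predicted = np.where(y_prob.flatten() > 0.5, 1, 0)  # 1 = spam, 0 = ham
cm = confusion_matrix(y_test, y_predicted)
# Rows are actual classes (0 = ham, 1 = spam), columns are predicted classes:
# [[TN, FP],
#  [FN, TP]]
print(cm)
```

To draw the matrix as a heatmap, it can be passed to seaborn, e.g. sns.heatmap(cm, annot=True, fmt='d').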

As we can see in the resulting confusion matrix, there are 160 true positives (spam correctly flagged) and 171 true negatives (ham correctly passed through), along with 27 false negatives and 16 false positives. That is 331 correct predictions out of 374 test messages, an accuracy of about 88.5%, so the model performs well on the test data.
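The standard metrics follow directly from the four counts. We assume the matrix is laid out as scikit-learn's confusion_matrix returns it (rows = actual ham/spam, columns = predicted ham/spam), which gives [[171, 16], [27, 160]]; this is the reading in which each row sums to 187, as expected from a stratified split of the balanced dataset. `metrics_from_confusion` is a hypothetical helper.

```python
def metrics_from_confusion(cm):
    """Compute accuracy, precision, recall and F1 from a 2x2 matrix [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp)  # of messages flagged as spam, how many really are spam
    recall = tp / (tp + fn)     # of real spam messages, how many were caught
    f1 = 2 * precision * recall / (precision + recall)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Assumed layout of the reported counts: rows = actual ham/spam, columns = predicted ham/spam
cm = [[171, 16], [27, 160]]
metrics = metrics_from_confusion(cm)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

Under that assumption, the test accuracy is about 0.885, precision about 0.909, and recall about 0.856.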

Now we make predictions for a set of new text messages. Because spam was encoded as 1 and ham as 0 during training, an output close to 1 indicates a spam message and an output close to 0 indicates ham. The following code predicts a spam probability for some example text messages and returns the results as an array.

In the above example, the output is an array of probabilities. Messages with a prediction above 0.5 are considered spam, and those below 0.5 are considered ham. Next, we write simple code to automate this spam/ham labelling.
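A minimal sketch of such automation, with a hypothetical helper `label_predictions` and made-up probabilities standing in for the output of model.predict on the example messages. A value of exactly 0.5 is labelled ham here, which is an arbitrary choice.

```python
import numpy as np


def label_predictions(messages, probabilities, threshold=0.5):
    """Pair each message with 'spam' if its probability exceeds the threshold, else 'ham'."""
    probs = np.asarray(probabilities).flatten()  # model.predict returns shape (n, 1)
    return [(msg, "spam" if p > threshold else "ham") for msg, p in zip(messages, probs)]


# Hypothetical messages and probabilities standing in for model.predict(messages)
messages = ["You have won a free cruise, reply now", "Are we still meeting for lunch?"]
probs = np.array([[0.87], [0.12]])
results = label_predictions(messages, probs)
for msg, label in results:
    print(f"{label}: {msg}")
```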

Final Words 

In this article, we covered the basics of the BERT model, word embeddings, and how BERT works. We then used BERT, with pre-trained weights downloaded from TensorFlow Hub, to classify the SMS Spam Collection dataset, and saw how easily text classification can be done with BERT's text preprocessing and word embedding features.


Basawanyya Hiremath
Basawanyya sees patterns around him. That's what makes him love machine learning; after all, it's all about patterns around us.
