How To Code Your First LSTM Network In Keras 

Standard neural networks are feedforward networks: the input data travels in one direction only, i.e., forward from the input nodes through the hidden layers to the output layer. Recurrent Neural Networks (RNNs), on the other hand, are a bit more complicated: their connections form cycles, so data can loop back through the network.

To put it a bit more technically, data moves through a Recurrent Neural Network along directed cycles between the nodes. This gives RNNs an ability that regular neural networks lack and that is vital when dealing with sequential data: they carry what they have learned from earlier inputs forward and use it when making the next prediction.

In this article, we will implement a simple Recurrent Neural Network with Keras and the MNIST dataset.


Pre-requisites:




  • An understanding of Recurrent Neural Networks

Why RNN

Feedforward networks also learn from data, and that is how they predict classes or values for new inputs. What makes RNNs different is that they rely on information from previous outputs when making a prediction for the next input. This feature is extremely useful when dealing with sequential data.

A common application of RNNs is Natural Language Processing. In all natural languages, the order of the words matters for conveying the meaning in the right context. To predict the next word of a sentence, the network must know what came before that word. RNNs can deal with any sequential data, including time series, video, and audio sequences.

RNNs keep a separate state (sometimes drawn as an extra layer) that stores the output for a given input. That stored output is fed back in as input at the next step, hence the name recurrent.
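The recurrence described above can be sketched in a few lines of NumPy. This is a minimal illustration, not Keras's implementation: the weight names `W_x`, `W_h`, `b` and the layer sizes are assumptions made up for the example.

```python
import numpy as np

def rnn_forward(inputs, W_x, W_h, b):
    """Run a vanilla RNN over one sequence and return every hidden state."""
    hidden = np.zeros(W_h.shape[0])      # initial state: nothing seen yet
    states = []
    for x_t in inputs:                   # process one time step at a time
        # The previous state is fed back in alongside the current input.
        hidden = np.tanh(W_x @ x_t + W_h @ hidden + b)
        states.append(hidden)
    return np.stack(states)

rng = np.random.default_rng(0)
n_features, n_hidden, n_steps = 3, 4, 5  # illustrative sizes only
W_x = rng.normal(size=(n_hidden, n_features))
W_h = rng.normal(size=(n_hidden, n_hidden))
b = np.zeros(n_hidden)

sequence = rng.normal(size=(n_steps, n_features))
states = rnn_forward(sequence, W_x, W_h, b)
print(states.shape)  # (5, 4): one hidden state per time step
```

Because `hidden` is reused inside the loop, the state at each step depends on every input seen so far, which is what lets the network use context.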

A Simple Introduction To LSTM Network

So we know that RNNs can remember characteristics of previous inputs and outputs. But for how long can they remember? In some cases, the immediately preceding output is not enough to predict what is coming, and the network has to rely on information from much earlier in the sequence.

For example, consider the phrase “the green grass” and the sentence “I live in France and I can speak French”. To predict the last word of the first phrase, “grass”, an RNN can rely on its immediately preceding output, “green”. To predict “French”, on the other hand, the network has to look back at an output that is much further away, “France”. This is called a long-term dependency. Unfortunately, as the gap between the related words grows, RNNs become unable to learn to connect the information.

Long Short-Term Memory (LSTM) networks are a special kind of RNN that deals with the long-term dependency problem effectively. Where a simple RNN repeats a single layer, an LSTM network has a repeating module made of four interacting neural network layers: a forget gate, an input gate, a candidate layer, and an output gate.
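To make the four interacting layers concrete, here is a rough NumPy sketch of a single LSTM time step using the commonly described formulation. The stacked weight matrices `W`, `U`, `b`, the gate ordering, and the sizes are assumptions chosen for illustration; Keras stores and orders its weights in its own way.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step; W, U and b hold the four layers stacked row-wise."""
    n = h_prev.shape[0]
    z = W @ x_t + U @ h_prev + b      # pre-activations of all four layers
    f = sigmoid(z[0:n])               # forget gate: how much old memory to keep
    i = sigmoid(z[n:2 * n])           # input gate: how much new content to admit
    g = np.tanh(z[2 * n:3 * n])       # candidate values for the cell state
    o = sigmoid(z[3 * n:4 * n])       # output gate: how much memory to expose
    c = f * c_prev + i * g            # updated cell state (long-term memory)
    h = o * np.tanh(c)                # updated hidden state (the layer output)
    return h, c

rng = np.random.default_rng(0)
n_features, n_hidden = 28, 8          # illustrative sizes only
W = rng.normal(scale=0.1, size=(4 * n_hidden, n_features))
U = rng.normal(scale=0.1, size=(4 * n_hidden, n_hidden))
b = np.zeros(4 * n_hidden)

h = np.zeros(n_hidden)
c = np.zeros(n_hidden)
sequence = rng.random(size=(28, n_features))  # e.g. 28 rows of 28 pixels
for x_t in sequence:
    h, c = lstm_step(x_t, h, c, W, U, b)
print(h.shape, c.shape)  # (8,) (8,)
```

The cell state `c` is only ever scaled by the forget gate and added to, which is what allows information to survive across many time steps.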

Let’s hand-code an LSTM network

Implementing LSTM with Keras

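We will use the LSTM network to classify the MNIST dataset of handwritten digits. Each 28×28 image is treated as a sequence of 28 time steps (the rows), with 28 features (the pixels in a row) per step.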

Importing Necessary Modules

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import CuDNNLSTM, Dense, Dropout, LSTM
from keras.optimizers import Adam

Importing And Preprocessing MNIST Data

#Importing the data
(X_train, y_train), (X_test, y_test) = mnist.load_data()  # images go to X_train/X_test, labels to y_train/y_test
#Normalizing the data: scale pixel values from 0-255 to the range 0-1
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
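If you want to check what the normalization does without downloading the dataset, the following sketch applies the same step to a small random array. The array `fake_images` is a made-up stand-in with the same dtype and per-image shape as the MNIST images, not real data.

```python
import numpy as np

# Stand-in for the image array from mnist.load_data():
# uint8 pixel values, shape (samples, 28, 28).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same normalization as in the article.
scaled = fake_images.astype('float32') / 255.0

print(scaled.dtype)      # float32
print(scaled.shape)      # (4, 28, 28)
print(scaled.shape[1:])  # (28, 28): (time steps, features) for the LSTM
```

The last line shows the value that `X_train.shape[1:]` supplies as `input_shape` when the first recurrent layer is created.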

Creating An LSTM Network

#Initializing the classifier Network
classifier = Sequential()

#Adding the input LSTM network layer
classifier.add(CuDNNLSTM(128, input_shape=(X_train.shape[1:]), return_sequences=True))
classifier.add(Dropout(0.2))

Note:

The return_sequences parameter, when set to True, makes the layer return its output at every time step rather than only at the last one. We set it to True because the next layer is also a recurrent layer and needs a full sequence as its input.
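The effect on output shapes can be illustrated with a toy recurrent layer in NumPy. This is a simplified stand-in written for this note (`toy_recurrent_layer` is a hypothetical helper, not a Keras function); it only mimics how the flag changes what is returned.

```python
import numpy as np

def toy_recurrent_layer(batch, W_x, W_h, return_sequences=False):
    """Toy tanh recurrent layer over a batch shaped (samples, timesteps, features)."""
    n_samples, n_steps, _ = batch.shape
    hidden = np.zeros((n_samples, W_h.shape[0]))
    outputs = []
    for t in range(n_steps):
        hidden = np.tanh(batch[:, t, :] @ W_x + hidden @ W_h)
        outputs.append(hidden)
    if return_sequences:
        return np.stack(outputs, axis=1)  # (samples, timesteps, units)
    return hidden                         # (samples, units): last step only

rng = np.random.default_rng(0)
units = 16                                # illustrative size only
batch = rng.random(size=(2, 28, 28))      # two fake 28x28 "images"
W_x = rng.normal(scale=0.1, size=(28, units))
W_h = rng.normal(scale=0.1, size=(units, units))

full = toy_recurrent_layer(batch, W_x, W_h, return_sequences=True)
last = toy_recurrent_layer(batch, W_x, W_h, return_sequences=False)
print(full.shape)  # (2, 28, 16)
print(last.shape)  # (2, 16)
```

A following recurrent layer expects three-dimensional input, so it needs the first form; a following Dense layer used for classification normally takes the second.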

#Adding a second LSTM network layer
classifier.add(CuDNNLSTM(128))

#Adding a dense hidden layer
classifier.add(Dense(64, activation='relu'))
classifier.add(Dropout(0.2))

#Adding the output layer
classifier.add(Dense(10, activation='softmax'))

Note:

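The CuDNNLSTM layer uses NVIDIA's cuDNN library and runs only on a GPU. If you do not have a GPU, use the regular LSTM layer instead; its activation function defaults to tanh and can be changed with the activation argument. In recent Keras and TensorFlow 2 releases, CuDNNLSTM is no longer available as a separate layer, and the standard LSTM layer uses the cuDNN kernel automatically when a GPU is present and the layer keeps its default arguments.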

Example:

classifier.add(LSTM(128, input_shape=(X_train.shape[1:]), return_sequences=True))
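As a sanity check on the architecture, the number of trainable parameters can be worked out by hand. The sketch below assumes the standard LSTM layer with one bias vector per internal layer; the CuDNN variant stores its biases differently, so the count it reports may differ slightly. The helper names are made up for this example.

```python
def lstm_params(input_dim, units):
    # Four internal layers (three gates plus the candidate layer), each with
    # input weights, recurrent weights, and one bias vector.
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # One weight per input-unit pair plus one bias per unit.
    return input_dim * units + units

layers = [
    ("first LSTM (28 features in, 128 units)", lstm_params(28, 128)),
    ("second LSTM (128 in, 128 units)", lstm_params(128, 128)),
    ("dense hidden (128 in, 64 units)", dense_params(128, 64)),
    ("dense output (64 in, 10 units)", dense_params(64, 10)),
]

for name, count in layers:
    print(f"{name}: {count}")

total = sum(count for _, count in layers)
print("total:", total)  # total: 220874
```

Dropout layers add no parameters, so they are left out of the list.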

Compiling The LSTM Network And Fitting The Data

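#Compiling the network
#sparse_categorical_crossentropy is used because the labels are integers (0-9), not one-hot vectors
#Note: recent Keras releases expect learning_rate instead of lr and no longer accept decay
classifier.compile(loss='sparse_categorical_crossentropy',
                   optimizer=Adam(lr=0.001, decay=1e-6),
                   metrics=['accuracy'])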

#Fitting the data to the model
classifier.fit(X_train,
               y_train,
               epochs=3,
               validation_data=(X_test, y_test))
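For readers wondering what the loss measures, here is a small NumPy sketch of softmax followed by sparse categorical cross-entropy. It follows the textbook definitions rather than Keras's internal code, and the logits and labels are made-up values.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum first for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    # Pick the predicted probability of the true class for each sample,
    # then average the negative log of those probabilities.
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(picked, eps, 1.0))))

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.1, 3.0]])   # made-up scores for 3 classes
labels = np.array([0, 2])              # integer class labels, as in MNIST

probs = softmax(logits)
loss = sparse_categorical_crossentropy(labels, probs)
print(probs.sum(axis=1))  # each row sums to 1
print(loss)               # small, because both samples favor the true class
```

"Sparse" only refers to the label format: the labels are plain integers, so no one-hot encoding of `y_train` is needed before fitting.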


Checking The Accuracy On Test Set

test_loss, test_acc = classifier.evaluate(X_test, y_test)
print('Test Loss: {}'.format(test_loss))
print('Test Accuracy: {}'.format(test_acc))
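The accuracy metric itself is simple to reproduce: take the most probable class for each sample and compare it with the true label. The sketch below uses a few made-up probability rows in place of real model predictions.

```python
import numpy as np

def accuracy(probabilities, labels):
    # The predicted class is the index of the highest probability in each row.
    predicted = np.argmax(probabilities, axis=1)
    return float(np.mean(predicted == labels))

# Made-up softmax outputs for four samples and three classes.
probs = np.array([[0.1, 0.8, 0.1],
                  [0.7, 0.2, 0.1],
                  [0.2, 0.3, 0.5],
                  [0.6, 0.3, 0.1]])
labels = np.array([1, 0, 2, 1])

print(accuracy(probs, labels))  # 0.75: three of four predictions match
```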


Happy Coding !!

Amal Nair
A Computer Science Engineer turned Data Scientist who is passionate about AI and all related technologies. Contact: amal.nair@analyticsindiamag.com
