Keras vs PyTorch vs Caffe – Comparing the Implementation of CNN

In today’s world, artificial intelligence is embedded in most business operations, and advanced deep learning frameworks have made it fairly easy to deploy. These frameworks provide high-level programming interfaces that help us design our deep learning models, and their built-in libraries reduce developers’ workload, allowing models to be built more quickly and easily.

In this article, we will build the same deep learning model, a convolutional neural network for image classification, on the same dataset (MNIST) in Keras, PyTorch and Caffe, and compare the three implementations. Finally, we will see how the CNN model built in PyTorch outperforms its peers built in Keras and Caffe.

Topics covered in this article

  • How to choose a deep learning framework
  • Pros and cons of Keras
  • Pros and cons of PyTorch
  • Pros and cons of Caffe
  • Hands-on implementation of the CNN model in Keras, PyTorch and Caffe

Choosing Deep Learning Frameworks

When choosing a deep learning framework, there are several criteria to consider: it should support parallel computation, provide a good interface for running our models, offer a large number of built-in packages and deliver good performance. The right choice also depends on our business problem and on how much flexibility we need. Let’s compare three of the most widely used deep learning frameworks: Keras, PyTorch and Caffe.

Deep Learning Framework in Keras

Keras is an open-source deep learning framework developed by Google engineer Francois Chollet. It makes it easy to build and evaluate models with just a few lines of code. If you are new to deep learning, Keras is an excellent framework to start with: it was designed to be user-friendly and to work smoothly with Python, and it ships with many pre-trained models (VGG, Inception, etc.). Beyond being easy to learn, it uses TensorFlow as its backend and can be used to deploy our models.

Limitations of using Keras

  • Some Keras features still need improvement.
  • Its user-friendliness comes at the cost of speed.
  • Training can sometimes take a long time, even on GPUs.

Hands-on implementation Using Keras Framework

In the code snippet below, we import the required libraries, set the hyperparameters and load the MNIST dataset.

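import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D

batch_size = 128
num_classes = 10
epochs = 12
img_rows, img_cols = 28, 28

# Load MNIST: 60,000 training and 10,000 test images of 28x28 pixels
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a channel axis (channels_last) and scale pixel values to [0, 1]
x_train = x_train.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255
x_test = x_test.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255

# One-hot encode the labels for categorical cross-entropy
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)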

In the code snippet below, we build a CNN with a few layers and set its activation functions, loss function and optimizer.

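model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(img_rows, img_cols, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())  # 12x12x64 feature maps -> 9216-dim vector
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))
# Loss and optimizer; Adadelta is an assumption, as in the classic Keras MNIST example
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.Adadelta(learning_rate=1.0), metrics=['accuracy'])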
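As a quick sanity check on the layer stack above, the plain-Python sketch below works out the output shapes and parameter counts by hand. It assumes 28x28x1 inputs, stride-1 'valid' (unpadded) convolutions as in Keras's defaults, and a Flatten layer between the pooling and the first Dense layer, as in the classic Keras MNIST example; the helpers `conv2d` and `dense` are hypothetical names, not Keras functions. Pooling, Flatten and Dropout layers have no trainable parameters.

```python
# Shape and parameter arithmetic for the Keras CNN (illustrative sketch).

def conv2d(h, w, c_in, filters, k):
    """Return (out_h, out_w, params) for a k x k 'valid', stride-1 convolution with bias."""
    return h - k + 1, w - k + 1, (k * k * c_in + 1) * filters

def dense(n_in, n_out):
    """Number of parameters of a fully connected layer with bias."""
    return (n_in + 1) * n_out

h, w = 28, 28
h, w, p1 = conv2d(h, w, 1, 32, 3)    # Conv2D(32, 3x3): 28x28 -> 26x26
h, w, p2 = conv2d(h, w, 32, 64, 3)   # Conv2D(64, 3x3): 26x26 -> 24x24
h, w = h // 2, w // 2                # MaxPooling2D(2x2): 24x24 -> 12x12
flat = h * w * 64                    # Flatten
p3 = dense(flat, 128)                # Dense(128)
p4 = dense(128, 10)                  # Dense(10)
total = p1 + p2 + p3 + p4
print((h, w, 64), flat, total)
```

Nearly all of the roughly 1.2 million parameters come from the first Dense layer (9216 × 128 weights plus 128 biases).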

In the code snippet below, we train and evaluate the model.

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Deep Learning Framework in PyTorch

PyTorch is also an open-source framework, developed by Facebook’s AI research team. It offers a Pythonic way of implementing deep learning models and works with all the services and functionality of the Python environment. It provides automatic differentiation (autograd), which computes the gradients needed for backpropagation automatically. PyTorch also comes with domain libraries such as torchvision, torchaudio and torchtext for computer vision, audio and NLP work. PyTorch is generally more flexible for researchers than for production developers.
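To make automatic differentiation concrete, here is a toy, scalar-only sketch of reverse-mode differentiation in plain Python. The `Value` class is invented for this illustration and is not how PyTorch is implemented; it only shows the chain-rule bookkeeping that autograd performs for you when you call `loss.backward()` on tensors.

```python
# Toy scalar reverse-mode automatic differentiation (illustration only).

class Value:
    def __init__(self, data, parents=(), grad_fns=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents      # values this one was computed from
        self._grad_fns = grad_fns    # map upstream gradient -> gradient for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other),
                     (lambda g: g, lambda g: g))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other),
                     (lambda g: g * other.data, lambda g: g * self.data))

    def backward(self):
        # Order the graph so each node comes after its inputs,
        # then push gradients from the output back to the inputs.
        order, seen = [], set()

        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)

        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            for parent, fn in zip(v._parents, v._grad_fns):
                parent.grad += fn(v.grad)

# y = w * x + b, the core computation of a linear layer
w, x, b = Value(2.0), Value(3.0), Value(1.0)
y = w * x + b
y.backward()
print(y.data, w.grad, x.grad, b.grad)
```

Here `y = w * x + b` gives dy/dw = x = 3, dy/dx = w = 2 and dy/db = 1, which is what the sketch accumulates in each `.grad`.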

Limitations of PyTorch

  • PyTorch is more popular among researchers than among developers.
  • Its support for production deployment is less mature.

Hands-on implementation Using PyTorch Framework

Importing required libraries

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as dataloader
import torch.optim as optim
from torch.utils.data import TensorDataset
from torchvision import transforms
from torchvision.datasets import MNIST

In the code snippet below, we load the MNIST training and test sets and wrap each in a DataLoader that shuffles it into batches of 64.

# ToTensor converts each 28x28 image into a 1x28x28 float tensor in [0, 1]
train = MNIST('./data', train=True, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
]))
test = MNIST('./data', train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
]))
dataloader_args = dict(shuffle=True, batch_size=64, num_workers=1, pin_memory=True)
train_loader = dataloader.DataLoader(train, **dataloader_args)
test_loader = dataloader.DataLoader(test, **dataloader_args)
train_data = train.data  # raw uint8 training images, shape (60000, 28, 28)
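A common choice for the transform list is `transforms.ToTensor()`, which converts each 28x28 grayscale MNIST image into a float tensor of shape (1, 28, 28) with pixel values scaled from [0, 255] to [0.0, 1.0]. The NumPy function below is a hypothetical stand-in (`to_tensor_like` is not a torchvision API) that mimics this conversion for a single image, so the effect can be inspected without torchvision.

```python
import numpy as np

def to_tensor_like(img_uint8):
    """Mimic ToTensor for a 2-D grayscale uint8 image:
    add a channel axis (H, W) -> (1, H, W) and scale [0, 255] -> [0.0, 1.0]."""
    return img_uint8[np.newaxis, :, :].astype(np.float32) / 255.0

img = np.zeros((28, 28), dtype=np.uint8)
img[10:18, 12:16] = 255                 # a bright stroke on a black background
t = to_tensor_like(img)
print(t.shape, t.dtype, t.min(), t.max())
```

With `batch_size=64`, the DataLoader stacks such tensors into batches of shape (64, 1, 28, 28) (the final batch can be smaller), which the model's `x.view((-1, 784))` flattens into rows of 784 pixels.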

In the code snippet below, we define the model, a fully connected network with batch normalization and dropout, and set its activation functions and optimizer.

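class Model(nn.Module):
    # Fully connected network: 784 -> 548 -> 252 -> 10, with batch norm and dropout
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(784, 548)
        self.bc1 = nn.BatchNorm1d(548)
        self.fc2 = nn.Linear(548, 252)
        self.bc2 = nn.BatchNorm1d(252)
        self.fc3 = nn.Linear(252, 10)

    def forward(self, x):
        a = x.view((-1, 784))            # flatten each 28x28 image
        b = self.fc1(a)
        b = self.bc1(b)
        b = F.relu(b)
        b = F.dropout(b, p=0.5, training=self.training)
        b = self.fc2(b)
        b = self.bc2(b)
        b = F.relu(b)
        b = F.dropout(b, p=0.2, training=self.training)
        b = self.fc3(b)
        out = F.log_softmax(b, dim=1)    # log-probabilities over the 10 classes
        return out

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001)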
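In its functional form, `F.dropout` drops units whenever its `training` argument is true, and that argument defaults to `True`; passing `training=self.training` lets `model.eval()` switch dropout off at test time. The NumPy sketch below is a hypothetical re-implementation of inverted dropout, the scheme PyTorch documents for this function (surviving values are scaled by 1/(1 − p) during training), showing the difference between the two modes.

```python
import numpy as np

def dropout(x, p, training, rng):
    """Inverted dropout: in training, zero each element with probability p and
    scale the survivors by 1 / (1 - p); in evaluation, return x unchanged."""
    if not training or p == 0.0:
        return x
    keep = rng.random(x.shape) >= p          # True with probability 1 - p
    return np.where(keep, x / (1.0 - p), 0.0)

rng = np.random.default_rng(0)
x = np.ones(100_000)
train_out = dropout(x, 0.5, training=True, rng=rng)
eval_out = dropout(x, 0.5, training=False, rng=rng)
print(train_out.mean(), eval_out.mean())
```

Scaling the surviving activations keeps their expected value unchanged, which is why no rescaling is needed at evaluation time.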

In the code snippet below, we train the model for 12 epochs, using cross-entropy as the loss function.

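losses = []
device = next(model.parameters()).device  # train on whichever device holds the model
model.train()
for epoch in range(12):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y_pred = model(data)                   # log-probabilities, shape (batch, 10)
        loss = F.nll_loss(y_pred, target)      # cross-entropy on log-softmax outputs
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if batch_idx % 100 == 1:
            print('\r Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(data),
                len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.item()), end='')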

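# Evaluating our model
model.eval()  # evaluation mode: batch norm uses its running statistics
device = next(model.parameters()).device
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1)  # most likely class for each image
        correct += pred.eq(target).sum().item()
accuracy = correct / len(test_loader.dataset)
print('Accuracy:', accuracy)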
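The model ends in `log_softmax`, so the negative log-likelihood of the correct class is exactly the cross-entropy loss. `F.nll_loss` computes that directly from log-probabilities, while `F.cross_entropy` applies `log_softmax` itself; on this model both give the same value, because applying `log_softmax` twice changes nothing. The NumPy check below illustrates both facts on made-up logits; `log_softmax` and `nll_loss` here are hypothetical re-implementations, not the PyTorch functions.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of the target classes."""
    return -log_probs[np.arange(len(targets)), targets].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 2])

log_probs = log_softmax(logits)
# Cross-entropy computed the long way, from softmax probabilities
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
cross_entropy = -np.mean(np.log(probs[np.arange(2), targets]))

print(nll_loss(log_probs, targets), cross_entropy)
# log_softmax is idempotent: applying it to log-probabilities returns them unchanged
print(np.allclose(log_softmax(log_probs), log_probs))
```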

Deep Learning Framework in Caffe

Caffe (Convolutional Architecture for Fast Feature Embedding) is an open-source deep learning framework developed by Yangqing Jia. It serves both researchers and industrial applications of artificial intelligence. Many developers choose Caffe for its speed: it can process 60 million images per day on a single NVIDIA K40 GPU. Caffe has many contributors who update and maintain the framework, and it works better for computer vision models than for other domains of deep learning.
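For a sense of scale, the 60-million-images-per-day figure converts to a per-second rate and a per-image time as follows:

```latex
\frac{60 \times 10^{6}\ \text{images}}{24 \times 3600\ \text{s}}
  = \frac{60\,000\,000}{86\,400}\ \text{images/s}
  \approx 694\ \text{images/s},
\qquad
\frac{1}{694\ \text{images/s}} \approx 1.44\ \text{ms per image}.
```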

Limitations of Caffe

  • Caffe doesn’t have a higher-level API, which makes experimentation harder.
  • To deploy a model with Caffe, we have to compile the source code.

Installing Caffe 

!apt install -y caffe-tools-cpu

Importing required libraries 

import os
import numpy as np
import math
import caffe
import lmdb

In the code snippet below, we configure Caffe’s logging and choose the hardware (CPU or GPU) to run on.

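os.environ["GLOG_minloglevel"] = '2'  # quiet Caffe's logging; only effective if set before `import caffe`
USE_GPU = False  # the caffe-tools-cpu package installed above is a CPU-only build
if USE_GPU:
    caffe.set_mode_gpu()  # requires a GPU-enabled Caffe build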

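In the code snippet below, we define two generators that handle data loading: image_generator reads the MNIST images one at a time from an LMDB database, and batch_generator groups them into batches matching the network’s input shape.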

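def image_generator(db_path):
    # Open the LMDB database read-only and iterate over its records
    db_handle = lmdb.open(db_path, readonly=True)
    with db_handle.begin() as db:
        cur = db.cursor()
        for _, value in cur:
            # Decode the serialized Caffe Datum into a NumPy array (C x H x W)
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(value)
            int_x = caffe.io.datum_to_array(datum)
            x = np.asarray(int_x, dtype=np.float32)
            yield x - 128  # roughly centre pixel values around zero

def batch_generator(shape, db_path):
    # Group single images into batches of shape (N, C, H, W)
    gen = image_generator(db_path)
    res = np.zeros(shape)
    while True:
        try:
            for i in range(shape[0]):
                res[i] = next(gen)
        except StopIteration:
            return  # database exhausted: end of epoch (a final partial batch is dropped)
        yield res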
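The batching logic can be exercised without Caffe or an LMDB database. In the sketch below, `fake_image_source` is a hypothetical stand-in for `image_generator` that yields numbered, centred 1x28x28 images, and `batch_generator_from` repeats the batching pattern with explicit `StopIteration` handling, which Python 3.7+ needs inside generators (PEP 479 turns an escaping `StopIteration` into a `RuntimeError`).

```python
import numpy as np

def fake_image_source(n, shape=(1, 28, 28)):
    """Hypothetical stand-in for image_generator: yields n centred images."""
    for k in range(n):
        x = np.full(shape, float(k), dtype=np.float32)
        yield x - 128

def batch_generator_from(source, shape):
    """Group single images from `source` into arrays of `shape` (N, C, H, W)."""
    res = np.zeros(shape)
    while True:
        try:
            for i in range(shape[0]):
                res[i] = next(source)
        except StopIteration:
            return  # source exhausted: stop, dropping any partial batch
        yield res.copy()  # copy so that collected batches stay independent

batches = list(batch_generator_from(fake_image_source(10), (4, 1, 28, 28)))
print(len(batches), batches[0].shape)
```

Ten images with a batch size of four give two full batches, and the last two images are dropped. The sketch yields a copy of each batch so the collected list holds distinct arrays; reusing a single `res` array, as the Caffe version does, is fine there because each batch is copied into the network’s input blob right away.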

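In the code snippet below, we load the LeNet network definition for MNIST (a .prototxt file) and create one copy of the network for the training phase and one for the test phase.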

net_path = "content/mnist/lenet_train_test.prototxt"
net = caffe.Net(net_path, caffe.TRAIN)
test_net = caffe.Net(net_path, caffe.TEST) 

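In the code snippet below, we train the model on the MNIST training set with a hand-written SGD loop rather than a Caffe solver, decaying the learning rate with Caffe’s "inv" policy.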

num_epochs = 1  # passes over the training set (assumed value; with 0 the loop below never runs)
iter_num = 0 
db_path = "content/mnist/mnist_train_lmdb"
db_path_test = "content/mnist/mnist_test_lmdb"
base_lr = 0.01
gamma = 1e-4
power = 0.75

for epoch in range(num_epochs):
    print("Starting epoch {}".format(epoch))
    input_shape = net.blobs["data"].data.shape
    for batch in batch_generator(input_shape, db_path):
        iter_num += 1
        net.blobs["data"].data[...] = batch
        net.forward()
        # Clear accumulated parameter gradients, then backpropagate the loss
        for l in net.layers:
            for b in l.blobs:
                b.diff[...] = 0
        net.backward()
        # Caffe's "inv" learning-rate policy
        learning_rate = base_lr * math.pow(1 + gamma * iter_num, - power)
        # Plain SGD update: parameter -= learning_rate * gradient
        for l in net.layers:
            for b in l.blobs:
                b.data[...] -= learning_rate * b.diff
        if iter_num % 50 == 0:
            print("Iter {}: loss={}".format(iter_num, net.blobs["loss"].data))
        if iter_num % 200 == 0:
            # test_network (not shown) returns (accuracy, loss) on the test set
            print("Testing network: accuracy={}, loss={}".format(*test_network(test_net, db_path_test)))
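The learning rate in this loop follows Caffe’s "inv" policy, lr = base_lr · (1 + gamma · iter)^(−power), and each parameter is updated as parameter − lr · gradient. The plain-Python sketch below reproduces the schedule with the same constants and applies the same update rule to a toy one-parameter objective, f(w) = (w − 3)², chosen only for illustration; `inv_lr` is a hypothetical helper name.

```python
import math

def inv_lr(iter_num, base_lr=0.01, gamma=1e-4, power=0.75):
    """Caffe's "inv" learning-rate policy: base_lr * (1 + gamma * iter) ** (-power)."""
    return base_lr * math.pow(1 + gamma * iter_num, -power)

# Manual SGD on f(w) = (w - 3)^2, mirroring the parameter update in the Caffe loop
w = 0.0
for iter_num in range(1, 2001):
    grad = 2.0 * (w - 3.0)        # df/dw plays the role of the blob's diff
    w -= inv_lr(iter_num) * grad  # parameter -= learning_rate * gradient
print(round(inv_lr(0), 6), round(inv_lr(10000), 6), round(w, 6))
```

With gamma = 1e-4 and power = 0.75, the rate has decayed to about 59% of base_lr by iteration 10,000, since 2^(−0.75) ≈ 0.595.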

Finally, the code snippet below reports the final accuracy and loss on the test set.

print("Training finished after {} iterations".format(iter_num))
print("Final performance: accuracy={}, loss={}".format(*test_network(test_net, db_path_test)))


In this article, we demonstrated how to implement a CNN model for image classification in three popular frameworks: Keras, PyTorch and Caffe. We saw that the CNN model developed in PyTorch outperformed the CNN models developed in Keras and Caffe in terms of accuracy and speed. As a beginner, I started my research work with Keras, which is a very easy framework for beginners, but its applications are limited. PyTorch and Caffe, on the other hand, are very powerful frameworks in terms of speed, optimization and parallel computation.

Prudhvi varma
AI enthusiast, currently working with Analytics India Magazine. I have experience working on machine learning and deep learning real-time problems, neural networks, and structuring machine learning projects. I am a computer vision researcher interested in solving real-time computer vision problems.
