Exploring Transfer Learning Using TensorFlow Keras

Transfer Learning is the approach of making use of an already trained deep learning model along with its weights for a related task

A good deep learning model has a carefully designed architecture. It needs enormous amounts of training data, capable hardware, skilled developers, and a great deal of time to train and tune hyperparameters before it reaches satisfactory performance. Building and training a deep learning model from scratch is therefore impractical for every deep learning task. This is where Transfer Learning comes in. Transfer Learning is the approach of reusing an already trained model for a related task.

In this article, we discuss Transfer Learning with necessary examples to perform image classification using TensorFlow Keras. This article assumes that readers have good knowledge of the fundamentals of deep learning and computer vision. Readers may have a look at the following basic articles to fulfil the prerequisites:

  1. Getting Started With Deep Learning Using TensorFlow Keras
  2. Getting Started With Computer Vision Using TensorFlow Keras

How Does Transfer Learning Work?

Each deep learning problem is unique in some sense, so an already trained model can rarely be reused as-is on a new problem; it usually needs some task-specific alterations. Transfer Learning provides those alterations. The already trained model that is reused via Transfer Learning is called a pre-trained model. A pre-trained model can be state-of-the-art in its domain, but our problem should belong to the same domain as the pre-trained model, or to a closely related one. For instance, a pre-trained model meant for image segmentation cannot be used as such for image classification.

Deep learning frameworks such as TensorFlow and PyTorch enable saving a model and its weights in portable formats. A carefully developed architecture learns to extract features from its input data. The same architecture, along with its trained weights, can also extract useful features from slightly different input data. Thus, the main goal of transfer learning is to reuse this learned feature extraction on new data.

We explore two ways of applying Transfer Learning in the sequel:

  1. Feature Extraction only
  2. Feature Extraction and Fine-tuning

Feature Extraction Using a Pre-trained Model

Import the necessary frameworks and libraries.

 import numpy as np
 import pandas as pd
 import tensorflow as tf
 from tensorflow import keras
 import tensorflow_datasets as tfds
 import matplotlib.pyplot as plt 

TensorFlow Datasets has a huge collection of pre-processed and vectorized datasets from different domains. Here, we discuss feature extraction using transfer learning with image classification problems. Load the built-in Oxford_Flowers102 dataset, which has images of flowers from 102 different classes.

 data, meta = tfds.load('oxford_flowers102', 
                        as_supervised=True,
                        with_info=True) 

Output:

Obtain train, validation and test sets from the data.

 raw_train = data['train']
 raw_val = data['validation']
 raw_test = data['test'] 

Sample an image and display it with its label.

 label_extractor = meta.features['label'].int2str
 for image,label in raw_train.take(1):
   plt.imshow(image)
   plt.title(label_extractor(label))
   plt.colorbar()
   plt.show() 

Output:

Visual inspection suggests that the image is a 3-channel colour image with pixel values ranging from 0 to 255. We can get the exact shape and bounding values using the following code.

 print(image.shape)
 # NumPy's own reductions are much faster than Python's built-in min/max here
 print(image.numpy().min(), image.numpy().max()) 

Output:

We now have our problem: an image classification dataset. Next, we need to choose a suitable pre-trained model. ImageNet is one of the most famous datasets used in image classification. The version commonly used for pre-training has more than a million images belonging to 1000 classes. Newly developed architectures are routinely trained and tested on this dataset. NASNetLarge, whose building blocks were discovered by Google’s Neural Architecture Search (NAS), a reinforcement-learning-based search procedure, rather than designed by hand, is one of the popular architectures.

The convolutional neural network part of the architecture is called the base, and the fully connected part (with Dense layers) is called the head. The base extracts features from the input images. The head performs classification using the extracted features. ImageNet has 1000 classes, but our dataset has only 102 classes. Hence, we use only the base of the pre-trained NASNetLarge model as a feature extractor and attach our own head.

 base = keras.applications.NASNetLarge(input_shape=(331,331,3),
                                       include_top=False,
                                       weights='imagenet') 

Output:

NASNetLarge expects its input to be in the shape of (331,331,3). We need to resize our images to conform to the requirements. Define a helper function to scale and resize the input images.

 def scale_image(img,label):
   img = tf.cast(img,tf.float32)
   img = img/255.0
   img = tf.image.resize(img, (331,331))
   return img,label  
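
The helper above rescales pixels to the range [0, 1]. For comparison, Keras also provides `keras.applications.nasnet.preprocess_input`, which maps pixels to [-1, 1]; this article does not establish which convention works better for this task. The framework-free sketch below only illustrates the two mappings on a few sample pixel values. The function names are illustrative and are not part of any library.

```python
def scale_unit(pixel):
    """Map a pixel from [0, 255] to [0, 1], as scale_image does."""
    return pixel / 255.0


def scale_symmetric(pixel):
    """Map a pixel from [0, 255] to [-1, 1] (the alternative convention)."""
    return pixel / 127.5 - 1.0


samples = [0, 51, 255]  # darkest, a dark-grey value, brightest
unit_scaled = [scale_unit(p) for p in samples]
symmetric_scaled = [scale_symmetric(p) for p in samples]

print(unit_scaled)
print(symmetric_scaled)
```

Whichever convention is chosen, the same scaling must be applied to the train, validation and test images.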

Scale and resize the train, validation and test images.

 train = raw_train.map(scale_image)
 val = raw_val.map(scale_image)
 test = raw_test.map(scale_image) 

Sample 25 images and display them with their text labels.

 plt.figure(figsize=(8,8))
 i=1
 for img, label in train.take(25):
   plt.subplot(5,5,i)
   plt.imshow(img)
   plt.title(label_extractor(label))
   plt.xticks([])
   plt.yticks([])
   i += 1
 plt.show() 

Output:

[Figure: a 5×5 grid of sample flower images with their labels]

Prepare the data in batches, since the model is trained on mini-batches.

 batch_size = 64
 train_batches = train.batch(batch_size)
 val_batches = val.batch(batch_size) 
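
To get a feel for what batching does to the length of an epoch, the sketch below computes the number of batches per epoch. It assumes the train split holds 1,020 images; that figure is an assumption here and can be checked in your own run with `meta.splits['train'].num_examples`. The helper name is illustrative.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to cover every example once."""
    return math.ceil(num_examples / batch_size)


num_train = 1020  # assumed size of the train split; verify against the metadata
batch_size = 64

steps = steps_per_epoch(num_train, batch_size)
# The final batch holds whatever is left after the full batches
last_batch = num_train - (steps - 1) * batch_size

print(steps, last_batch)
```

Because `batch` keeps the final partial batch by default, the last step of each epoch sees fewer images than the others.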

We use the pre-trained base and its weights as they are, so we freeze them to prevent any further training.

base.trainable=False

Develop a classification head to classify 102 classes.

 head = keras.models.Sequential([
                                 keras.layers.GlobalAveragePooling2D(),
                                 keras.layers.Dense(512, activation='relu'),
                                 keras.layers.Dropout(0.5),
                                 keras.layers.Dense(102, activation='softmax')  # one unit per flower class
 ]) 
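
The first layer of the head, `GlobalAveragePooling2D`, collapses each feature map produced by the base into a single number by averaging over its height and width. The sketch below is a plain-Python illustration of that idea on a tiny made-up feature map; it is not how Keras implements the layer.

```python
def global_average_pool(feature_map):
    """Average a (height, width, channels) nested list over height and width."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    totals = [0.0] * channels
    for row in feature_map:
        for cell in row:
            for c, value in enumerate(cell):
                totals[c] += value
    return [t / (height * width) for t in totals]


# A toy 2x2 feature map with 2 channels (values invented for illustration)
toy_map = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0]],
]
pooled = global_average_pool(toy_map)
print(pooled)
```

The result has one value per channel, which is what the following Dense layer receives as its input vector.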

Build the final model with the non-trainable base (feature extractor) and the head.

model = keras.models.Sequential([base,head])

We can explore the number of parameters in the model.

base.summary()

A portion of the output:

The base has around 85 million parameters, none of which are trainable (pre-trained model). 

head.summary()

Output:

The head has around 2 million parameters, all of which are trainable.
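
The figure of roughly 2 million can be reproduced by hand. The sketch below assumes that the base emits 4,032 channels after pooling (worth verifying against `base.output_shape`) and that the head ends in a 102-unit output layer, as the text describes. Pooling and dropout layers hold no weights, so only the two Dense layers count.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


base_channels = 4032  # assumption: channel depth of the base's final feature map

hidden = dense_params(base_channels, 512)  # Dense(512) on the pooled features
output = dense_params(512, 102)            # Dense(102) classification layer
total = hidden + output

print(hidden, output, total)
```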

model.summary()

Output:

We compile the model with a suitable loss function, an optimizer and an evaluation metric.

 model.compile(loss='sparse_categorical_crossentropy',
               optimizer='adam',
               metrics=['accuracy']) 
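
The loss `sparse_categorical_crossentropy` takes integer labels and compares them against the softmax probabilities produced by the head. As a rough, framework-free sketch of the computation for a single example (function names mirror the Keras names but are written here from scratch):

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(label, probs):
    """Negative log-probability assigned to the true integer label."""
    return -math.log(probs[label])


num_classes = 102
# An untrained head that has no preference gives every class equal probability
uniform_probs = softmax([0.0] * num_classes)
initial_loss = sparse_categorical_crossentropy(7, uniform_probs)

print(initial_loss)
```

With equal probabilities over 102 classes the loss equals ln(102), so a loss in that neighbourhood at the very start of training simply means the head has not learned anything yet.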

Our job is to train the head of the model on the input data while the base remains unchanged. We can define an early stopping callback, which stops training when the validation performance stops improving. Here, we set the argument patience to 3. It is the number of consecutive epochs without improvement after which training stops. For highly unstable performance (zig-zag performance curves), a higher patience is preferred.

callback = keras.callbacks.EarlyStopping(monitor='val_loss',patience=3)
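
To make the role of patience concrete, here is a simplified plain-Python sketch of the stopping rule: keep the best validation loss seen so far, count consecutive epochs that fail to beat it, and stop once that count reaches the patience. It is an illustration of the idea, not the Keras source, and the loss values are invented.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 1-based epoch at which training would stop."""
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)  # patience never exhausted


# Invented validation losses: improvement stalls after the 3rd epoch
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]
stopped_at = early_stopping_epoch(losses, patience=3)
print(stopped_at)
```

Note that the seventh value would have been an improvement, but training had already stopped one epoch earlier. This is the trade-off behind choosing a higher patience for noisy curves.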

Train the model for up to 100 epochs. The early stopping callback will usually halt training well before that.

history = model.fit(train_batches, validation_data=val_batches, epochs=100, callbacks=[callback])

A portion of the output:

We observe that training stopped shortly after the 30th epoch because the validation loss had stopped improving. Visualize the losses and accuracies to get better insight into the training.

 hist = pd.DataFrame(history.history)
 length = len(hist['loss'])
 epochs = np.arange(1,length+1)

 plt.plot(epochs,hist['loss'], label='Train Loss')
 plt.plot(epochs,hist['val_loss'], label='Val Loss')
 plt.legend()
 plt.ylabel('Loss')
 plt.xlabel('Epochs')
 plt.xticks(np.arange(1,length+1,2))
 plt.show()  

Output:

[Figure: training and validation loss per epoch]

Plot the accuracies in the same way.

 plt.plot(epochs,hist['accuracy'], label='Train Accuracy')
 plt.plot(epochs,hist['val_accuracy'], label='Val Accuracy')
 plt.legend()
 plt.ylabel('Accuracy')
 plt.xlabel('Epochs')
 plt.xticks(np.arange(1,length+1,2))
 plt.show()  

Output:

[Figure: training and validation accuracy per epoch]

Both training and validation performance saturate at around the 10th epoch, with no notable improvement afterwards. A likely reason is that the base was originally trained to extract features from the ImageNet dataset, and our dataset may have some feature differences that the frozen base cannot capture. We handle this issue with another dataset in the next section.

Fine-tuning a Pre-trained Model

When the features of the present images differ from those of the images the model was originally trained on, performance suffers, even if the pre-trained model is a well-acclaimed state-of-the-art one. We look at another dataset here, the Cassava Leaf Disease dataset, which is built into TensorFlow Datasets. Every image in this dataset shows leaves, with only minute differences according to their disease states.

 cassava, meta = tfds.load('cassava',
                           with_info=True,
                           as_supervised=True) 

Output: 

Obtain the split data.

 raw_train  = cassava['train']
 raw_val = cassava['validation']
 raw_test = cassava['test'] 

Metadata gives the details about the dataset.

meta

Output:

The dataset has 5 classes: one healthy and four different disease classes.

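 # Rebind the label lookup to the cassava metadata before using it
 label_extractor = meta.features['label'].int2str
 for i in range(5):
   print(f'{i}: {label_extractor(i)}') 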

Output:

Sample an image and visualize it.

 label_extractor = meta.features['label'].int2str
 for img, label in raw_train.skip(1).take(1):
   plt.imshow(img)
   plt.title(label_extractor(label))
   plt.colorbar()
   plt.show() 

Output:

[Figure: sample cassava leaf image with its label]

Check the image shape and the range of pixel values.

 print(img.shape)
 min(img.numpy().ravel()), max(img.numpy().ravel()) 

Output:

The images are 3-channel colour images with pixel values ranging from 0 to 255, as before. Scale and resize the images.

 train = raw_train.map(scale_image)
 val = raw_val.map(scale_image)
 test = raw_test.map(scale_image) 

Sample 25 resized images and visualize them along with their labels.

 plt.figure(figsize=(8,8))
 i=1
 for img, label in train.take(25):
   plt.subplot(5,5,i)
   plt.imshow(img)
   plt.title(label_extractor(label))
   plt.xticks([])
   plt.yticks([])
   i += 1
 plt.show() 

Output: 

[Figure: a 5×5 grid of sample cassava leaf images with their labels]

Prepare batches of datasets.

 batch_size = 128
 train_batches = train.batch(batch_size)
 val_batches = val.batch(batch_size) 

Use the same feature extractor base, but develop a new head to classify 5 classes.

 base.trainable=False
 head = keras.models.Sequential([
                                 keras.layers.GlobalAveragePooling2D(),
                                 keras.layers.Dense(256, activation='relu'),
                                 keras.layers.Dropout(0.5),
                                 keras.layers.Dense(5, activation='softmax')
 ])
 
 model = keras.models.Sequential([base,head]) 

Compile the model and visualize the number of parameters.

 model.compile(loss='sparse_categorical_crossentropy',
               optimizer='adam',
               metrics=['accuracy'])
 model.summary() 

Output:

As before, the base parameters are non-trainable, and the head parameters are trainable. Train the model for 5 epochs only.

history = model.fit(train_batches, validation_data=val_batches, epochs=5)

Visualize the performance.

 hist = history.history
 loss = hist['loss']
 val_loss = hist['val_loss']
 acc = hist['accuracy']
 val_acc = hist['val_accuracy']
 plt.plot(loss, label='Train Loss')
 plt.plot(val_loss, label='Val Loss')
 plt.legend()
 plt.ylabel('Loss')
 plt.xlabel('Epochs')
 plt.show()  

Output:

[Figure: training and validation loss over the first 5 epochs]

Plot the accuracies in the same way.

 plt.plot(acc, label='Train Accuracy')
 plt.plot(val_acc, label='Val Accuracy')
 plt.legend()
 plt.ylabel('Accuracy')
 plt.xlabel('Epochs')
 plt.show()  

Output:

[Figure: training and validation accuracy over the first 5 epochs]

After some initial training (here, 5 epochs), we unfreeze a few of the top layers in the base and train them so that they extract task-specific features more precisely. The bottom layers remain frozen. This is called fine-tuning.

len(base.layers)

Output:

There are 1039 layers in the base architecture. We freeze the bottom 700 layers and fine-tune the top layers.

 base.trainable = True
 for layer in base.layers[:700]:
   layer.trainable = False 
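
The slice `base.layers[:700]` selects the layers at indices 0 to 699, so the layer at index 700 is the first one left trainable. The sketch below mimics this with a stand-in class (`ToyLayer` is hypothetical, not a Keras class) and uses the layer count of 1039 reported above.

```python
from dataclasses import dataclass


@dataclass
class ToyLayer:
    """Stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = True


# Unfreeze everything first, then freeze the bottom 700 layers again
layers = [ToyLayer(f"layer_{i}") for i in range(1039)]
for layer in layers[:700]:
    layer.trainable = False

frozen = sum(not layer.trainable for layer in layers)
tunable = sum(layer.trainable for layer in layers)
print(frozen, tunable)
```

On the real model, the same kind of count over `base.layers` is a quick sanity check before recompiling.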

Compile the model once again for the changes to take effect.

 model.compile(loss='sparse_categorical_crossentropy',
               optimizer='adam',
               metrics=['accuracy']) 

Continue training from the 6th epoch through the 10th epoch.

 fine_tune = model.fit(train_batches, 
                       validation_data=val_batches,
                       initial_epoch=5,
                       epochs=10)  
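
When `initial_epoch` is given, `epochs` acts as the index of the final epoch rather than the number of additional epochs to run. The small sketch below spells out which epochs the call above covers; the helper name is illustrative.

```python
def epochs_to_run(initial_epoch, epochs):
    """1-based epoch numbers executed by a resumed fit call."""
    return [e + 1 for e in range(initial_epoch, epochs)]


resumed = epochs_to_run(initial_epoch=5, epochs=10)
print(resumed)
```

So the fine-tuning phase adds five more epochs to the five already completed, which is why the two histories can be concatenated into a single 10-epoch curve.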

Visualize the performance of the model.

 fine_hist = fine_tune.history
 loss += fine_hist['loss']
 val_loss += fine_hist['val_loss']
 acc += fine_hist['accuracy']
 val_acc += fine_hist['val_accuracy']
 plt.plot(loss, label='Train Loss')
 plt.plot(val_loss, label='Val Loss')
 plt.plot([4,4],[0.5,1.0], '--', label='Fine Tuning Starts')
 plt.legend()
 plt.ylabel('Loss')
 plt.xlabel('Epochs')
 plt.show()  

Output:

[Figure: training and validation loss per epoch, with the start of fine-tuning marked]

Plot the accuracies in the same way.

 plt.plot(acc, label='Train Accuracy')
 plt.plot(val_acc, label='Val Accuracy')
 plt.plot([4,4],[0.6,1.0], '--', label='Fine Tuning Starts')
 plt.legend()
 plt.ylabel('Accuracy')
 plt.xlabel('Epochs')
 plt.show()  

Output:

[Figure: training and validation accuracy per epoch, with the start of fine-tuning marked]

We observe that model performance improves sharply after fine-tuning. Careful selection of the number of trainable layers, the optimizer, the learning rate and other training settings may lead to further improvement.

Wrapping Up

In this article, we have discussed Transfer Learning with image classification problems. We have explored the two Transfer Learning strategies with real-life examples using TensorFlow Keras. 

Rajkumar Lakshmanamoorthy
A geek in Machine Learning with a Master's degree in Engineering and a passion for writing and exploring new things. Loves reading novels, cooking, practicing martial arts, and occasionally writing novels and poems.
