A good deep learning model has a carefully designed architecture. It also needs a large amount of training data, capable hardware, skilled developers, and considerable time for training and hyperparameter tuning to reach satisfactory performance. Building and training a deep learning model from scratch is therefore impractical for every task. This is where Transfer Learning comes in: the approach of reusing an already trained model for a related task.
In this article, we discuss Transfer Learning with necessary examples to perform image classification using TensorFlow Keras. This article assumes that readers have good knowledge of the fundamentals of deep learning and computer vision. Readers may have a look at the following basic articles to fulfil the prerequisites:
- Getting Started With Deep Learning Using TensorFlow Keras
- Getting Started With Computer Vision Using TensorFlow Keras
How Does Transfer Learning Work?
Each deep learning problem is unique in some way, so an already trained model can rarely be reused as-is for a new problem; it usually needs some task-specific alterations. This is exactly what Transfer Learning provides. The already trained model that is reused is called a pre-trained model. A pre-trained model can be state-of-the-art in its domain, but our problem must belong to the same domain as the pre-trained model. For instance, a pre-trained model meant for image segmentation cannot be utilized for image classification.
Deep learning frameworks such as TensorFlow and PyTorch allow a model and its weights to be saved in portable formats. A carefully developed architecture learns to extract features from its input data, and the same architecture with its trained weights can usually extract good features from slightly different input data as well. Thus, the main goal of transfer learning is to reuse this feature-extraction capability on new data.
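For instance, a minimal sketch of saving and reloading a Keras model in such a portable format (the tiny model and the file name demo_model.h5 below are only placeholders):

from tensorflow import keras

# A tiny placeholder model, just to illustrate saving and loading.
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])

# Save the architecture and weights to disk ('demo_model.h5' is hypothetical).
model.save('demo_model.h5')

# Reload the model later, possibly in another script or on another machine.
restored = keras.models.load_model('demo_model.h5')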
We explore two ways of applying Transfer Learning in the following sections:
- Feature Extraction only
- Feature Extraction and Fine-tuning
Feature Extraction Using a Pre-trained Model
Import the necessary frameworks and libraries.
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
TensorFlow Datasets has a huge collection of pre-processed and vectorized datasets from different domains. Here, we discuss feature extraction using transfer learning on an image classification problem. Load the in-built oxford_flowers102 dataset, which has images of flowers from 102 different classes.
data, meta = tfds.load('oxford_flowers102', as_supervised=True, with_info=True)
Output:
Obtain train, validation and test sets from the data.
raw_train = data['train']
raw_val = data['validation']
raw_test = data['test']
Sample an image and display it with its label.
label_extractor = meta.features['label'].int2str

for image, label in raw_train.take(1):
    plt.imshow(image)
    plt.title(label_extractor(label))
    plt.colorbar()
    plt.show()
Output:
Visual observation suggests that the image is a 3-channel colour image with pixel values ranging from 0 to 255. We can confirm the exact shape and value range using the following code.
print(image.shape)
min(image.numpy().ravel()), max(image.numpy().ravel())
Output:
With the image classification dataset ready, we need to choose a suitable pre-trained model. ImageNet is one of the most famous datasets used in image classification. It has more than a million images belonging to 1000 classes, and newly developed architectures are routinely trained and benchmarked on it. NASNetLarge, whose architecture was discovered by Google's Neural Architecture Search (NAS) reinforcement learning framework rather than designed by hand, is one of the popular architectures.
The convolutional neural network part of the architecture is called the base, and the fully connected part (the Dense layers) is called the head. The base extracts features from the input images, and the head performs classification using those features. ImageNet has 1000 classes, but our dataset has only 102 classes. Hence, we use only the base of the pre-trained NASNetLarge model, as a feature extractor.
base = keras.applications.NASNetLarge(input_shape=(331,331,3), include_top=False, weights='imagenet')
Output:
NASNetLarge expects its input to be in the shape of (331,331,3). We need to resize our images to conform to the requirements. Define a helper function to scale and resize the input images.
def scale_image(img, label):
    img = tf.cast(img, tf.float32)
    img = img / 255.0
    img = tf.image.resize(img, (331,331))
    return img, label
Scale and resize the train, validation and test images.
train = raw_train.map(scale_image)
val = raw_val.map(scale_image)
test = raw_test.map(scale_image)
Sample 25 images and display them with their text labels.
plt.figure(figsize=(8,8))
i = 1
for img, label in train.take(25):
    plt.subplot(5,5,i)
    plt.imshow(img)
    plt.title(label_extractor(label))
    plt.xticks([])
    plt.yticks([])
    i += 1
plt.show()
Output:
Prepare the data in batches, since model training consumes batched input.
batch_size = 64
train_batches = train.batch(batch_size)
val_batches = val.batch(batch_size)
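As an optional refinement, the batching step above can also shuffle the training data and prefetch batches so that input preparation overlaps with training; a sketch of this alternative (not used in the original run):

# Optional alternative pipeline: shuffle the training set and prefetch batches.
AUTOTUNE = tf.data.AUTOTUNE  # tf.data.experimental.AUTOTUNE on older TF 2.x
train_batches = train.shuffle(1000).batch(batch_size).prefetch(AUTOTUNE)
val_batches = val.batch(batch_size).prefetch(AUTOTUNE)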
We use the pre-trained base and its weights as they are, so we must not train them again.
base.trainable=False
Develop a classification head to classify 102 classes.
head = keras.models.Sequential([
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(102, activation='softmax')
])
Build the final model with the non-trainable base (feature extractor) and the head.
model = keras.models.Sequential([base,head])
We can explore the number of parameters in the model.
base.summary()
A portion of the output:
The base has around 85 million parameters, none of which are trainable because we have frozen the pre-trained base.
head.summary()
Output:
The head has around 2 million parameters, all of which are trainable.
model.summary()
Output:
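If we prefer to compute these counts programmatically rather than reading them off the summaries, a small helper along these lines (not part of the original walkthrough) tallies the trainable and non-trainable weights:

def count_params(weights):
    # Sum the number of elements across a list of weight tensors.
    return int(sum(np.prod(w.shape) for w in weights))

print('Trainable params:', count_params(model.trainable_weights))
print('Non-trainable params:', count_params(model.non_trainable_weights))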
We compile the model with a suitable loss function, an optimizer and an evaluation metric. Sparse categorical cross-entropy is used because the labels are integer-encoded rather than one-hot encoded.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
Our job is to train the head of the model on the input data while the base stays frozen. We define an early stopping callback, which stops training when there is no meaningful improvement in the validation performance. Here, we set the patience argument to 3: the number of epochs for which training continues even if there is no improvement. For highly unstable (zig-zag) performance curves, a higher patience is preferred.
callback = keras.callbacks.EarlyStopping(monitor='val_loss',patience=3)
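A variant worth knowing about, though not used in the run below, is restore_best_weights=True, which rolls the model back to the weights of the best validation epoch once training stops:

# Alternative callback (illustrative only): restore the best weights on stop.
callback_best = keras.callbacks.EarlyStopping(monitor='val_loss',
                                              patience=3,
                                              restore_best_weights=True)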
Train the model for up to 100 epochs. The early stopping callback will stop training at an earlier epoch.
history = model.fit(train_batches,
                    validation_data=val_batches,
                    epochs=100,
                    callbacks=[callback])
A portion of the output:
We observe that training stopped just after the 30th epoch because the validation loss had stopped improving. Visualize the losses and accuracies to get better insight into the training.
hist = pd.DataFrame(history.history)
length = len(hist['loss'])
epochs = np.arange(1, length+1)

plt.plot(epochs, hist['loss'], label='Train Loss')
plt.plot(epochs, hist['val_loss'], label='Val Loss')
plt.legend()
plt.ylabel('Loss')
plt.xlabel('Epochs')
plt.xticks(np.arange(1, length+1, 2))
plt.show()
Output:
plt.plot(epochs, hist['accuracy'], label='Train Accuracy')
plt.plot(epochs, hist['val_accuracy'], label='Val Accuracy')
plt.legend()
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.xticks(np.arange(1, length+1, 2))
plt.show()
Output:
Both training and validation performance saturate at around the 10th epoch, with no remarkable improvement afterwards. This is because the base was originally trained to extract features from the ImageNet dataset; our dataset may contain subtle features that the frozen base cannot extract. We address this issue with another dataset in the next section.
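As an aside, the test split prepared earlier has not been used so far; evaluating the feature-extraction model on it would look roughly as follows (a sketch, with results not shown in the original run):

# Evaluate the frozen-base model on the held-out test split.
test_batches = test.batch(batch_size)
test_loss, test_acc = model.evaluate(test_batches)
print(f'Test accuracy: {test_acc:.3f}')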
Fine-tuning a Pre-trained Model
When the features of the new images differ from those of the images used for the original training, performance suffers, even if the pre-trained model is a well-acclaimed state-of-the-art architecture. We look at another dataset here, the Cassava Leaf Disease dataset, available in-built with TensorFlow Datasets. Every image in this problem contains a leaf, with only minute differences depending on its disease state.
cassava, meta = tfds.load('cassava', with_info=True, as_supervised=True)
Output:
Obtain the split data.
raw_train = cassava['train']
raw_val = cassava['validation']
raw_test = cassava['test']
The metadata gives the details of the dataset.
meta
Output:
The dataset has 5 classes: one healthy and four different disease classes.
label_extractor = meta.features['label'].int2str

for i in range(5):
    print(f'{i}: {label_extractor(i)}')
Output:
Sample an image and visualize it.
for img, label in raw_train.skip(1).take(1):
    plt.imshow(img)
    plt.title(label_extractor(label))
    plt.colorbar()
    plt.show()
print(img.shape)
min(img.numpy().ravel()), max(img.numpy().ravel())
Output:
The images are 3-channel colour images with pixel values ranging from 0 to 255, as before. Scale and resize the images.
train = raw_train.map(scale_image)
val = raw_val.map(scale_image)
test = raw_test.map(scale_image)
Sample 25 resized images and visualize them along with their labels.
plt.figure(figsize=(8,8))
i = 1
for img, label in train.take(25):
    plt.subplot(5,5,i)
    plt.imshow(img)
    plt.title(label_extractor(label))
    plt.xticks([])
    plt.yticks([])
    i += 1
plt.show()
Output:
Prepare the datasets in batches.
batch_size = 128
train_batches = train.batch(batch_size)
val_batches = val.batch(batch_size)
Use the same feature-extractor base, but develop a new head to classify the 5 classes.
base.trainable = False

head = keras.models.Sequential([
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(5, activation='softmax')
])

model = keras.models.Sequential([base, head])
Compile the model and visualize the number of parameters.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.summary()
Output:
As before, the base parameters are non-trainable, and the head parameters are trainable. Train the model for 5 epochs only.
history = model.fit(train_batches, validation_data=val_batches, epochs=5)
Visualize the performance.
hist = history.history
loss = hist['loss']
val_loss = hist['val_loss']
acc = hist['accuracy']
val_acc = hist['val_accuracy']

plt.plot(loss, label='Train Loss')
plt.plot(val_loss, label='Val Loss')
plt.legend()
plt.ylabel('Loss')
plt.xlabel('Epochs')
plt.show()
Output:
plt.plot(acc, label='Train Accuracy')
plt.plot(val_acc, label='Val Accuracy')
plt.legend()
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.show()
Output:
After some initial training (here, 5 epochs), we unfreeze and train a few of the top layers of the base so that they learn to extract the task-specific features more precisely. The bottom layers remain frozen. This is called fine-tuning.
len(base.layers)
Output:
There are 1039 layers in the base architecture. We freeze the bottom 700 layers and fine-tune the top layers.
base.trainable = True

for layer in base.layers[:700]:
    layer.trainable = False
Compile the model once again for the changes to take effect.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
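A common refinement at this stage, though not applied in the run below, is to recompile with a much smaller learning rate so that the first fine-tuning updates do not wash out the pre-trained features; a sketch:

# Optional alternative: fine-tune with a reduced learning rate (illustrative only).
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.Adam(learning_rate=1e-5),
              metrics=['accuracy'])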
Continue training from the 6th epoch until the 10th epoch.
fine_tune = model.fit(train_batches, validation_data=val_batches, initial_epoch=5, epochs=10)
Visualize the performance of the model.
fine_hist = fine_tune.history
loss += fine_hist['loss']
val_loss += fine_hist['val_loss']
acc += fine_hist['accuracy']
val_acc += fine_hist['val_accuracy']

plt.plot(loss, label='Train Loss')
plt.plot(val_loss, label='Val Loss')
plt.plot([4,4], [0.5,1.0], '--', label='Fine Tuning Starts')
plt.legend()
plt.ylabel('Loss')
plt.xlabel('Epochs')
plt.show()
Output:
plt.plot(acc, label='Train Accuracy')
plt.plot(val_acc, label='Val Accuracy')
plt.plot([4,4], [0.6,1.0], '--', label='Fine Tuning Starts')
plt.legend()
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.show()
Output:
We observe that model performance gets a sudden boost after fine-tuning. Careful selection of the number of trainable layers, the optimizer, the learning rate and the training configuration may lead to further improvements.
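As a final check, the fine-tuned model can be evaluated on the so-far unused test split and applied to a single image; a minimal sketch, with results not part of the original run:

# Evaluate the fine-tuned model on the held-out test split.
test_batches = test.batch(batch_size)
model.evaluate(test_batches)

# Classify one test image and compare the prediction with the true label.
for img, label in test.take(1):
    probs = model.predict(tf.expand_dims(img, axis=0))[0]
    print('Predicted:', label_extractor(int(np.argmax(probs))),
          '| Actual:', label_extractor(label))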
Wrapping Up
In this article, we discussed Transfer Learning for image classification problems and explored the two Transfer Learning strategies, feature extraction and fine-tuning, with real-life examples using TensorFlow Keras.