Supercharge Your Image Classifier: An Intro to Transfer Learning

Ever wanted to build a state-of-the-art image classifier without needing a supercomputer and a massive dataset? You’re in luck! The secret is transfer learning, a powerful technique that lets you “stand on the shoulders of giants.”
In this guide, I’ll show you how to build a highly accurate cat vs. dog classifier by leveraging a model that has already been trained by experts at Google. Let’s dive in!
Step 1: Loading the Dataset
First things first, we need data. We’ll use the popular cats_vs_dogs dataset, which we can load directly from TensorFlow Datasets. The images in this dataset come in various sizes, which is a common real-world scenario we’ll need to handle.
The dataset ships with a single 'train' split, so we’ll carve it into three slices ourselves: 80% for training, 10% for validation (to check our progress during training), and 10% for testing.
import tensorflow as tf
import tensorflow_datasets as tfds
# Split the data into training, validation, and testing sets
(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,  # Returns (image, label) pairs
)
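To get a feel for how many images land in each slice, here’s a quick back-of-the-envelope sketch. It assumes the dataset holds 23,262 images (an assumption; check metadata.splits['train'].num_examples for the real figure), and split_sizes is a hypothetical helper, not part of TFDS. TFDS may round slice boundaries slightly differently, so treat the numbers as approximate.

```python
def split_sizes(total, boundaries):
    """Return the size of each slice given cumulative percentage boundaries."""
    # Turn percentages such as [80, 90] into absolute cut points.
    edges = [0] + [round(total * pct / 100) for pct in boundaries] + [total]
    # Each slice runs from one cut point to the next.
    return [hi - lo for lo, hi in zip(edges, edges[1:])]

TOTAL_EXAMPLES = 23262  # assumption: verify against the dataset metadata

train_n, val_n, test_n = split_sizes(TOTAL_EXAMPLES, [80, 90])
print(f"train={train_n} validation={val_n} test={test_n}")
```

The three sizes always add up to the total, so no image is dropped or used in two slices.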
Step 2: Prepping the Data for Our Model
Our model needs all images to be the exact same size and format. We’ll create a simple preprocessing function to:
- Resize every image to a standard 160x160 pixels.
- Normalize the pixel values. Instead of the usual 0-255 range, we’ll scale them to a -1 to 1 range, which is the format the pre-trained model expects.
IMG_SIZE = 160 # All images will be resized to 160x160
def format_example(image, label):
    """Resizes and normalizes images."""
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1  # Normalize to [-1, 1] range
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label
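If the (image / 127.5) - 1 arithmetic looks a bit magical, this small sketch checks it on single pixel values in plain Python. The helper names normalize_pixel and denormalize_pixel are made up for illustration; the real function above applies the same formula to whole image tensors.

```python
def normalize_pixel(value):
    """Map an 8-bit pixel value in [0, 255] to the range [-1, 1]."""
    return value / 127.5 - 1.0

def denormalize_pixel(value):
    """Invert normalize_pixel, mapping [-1, 1] back to [0, 255]."""
    return (value + 1.0) * 127.5

# Black, mid-gray, and white pixels mark the ends and the center of the range.
for pixel in (0, 127.5, 255):
    print(pixel, "->", normalize_pixel(pixel))
```

Dividing by 127.5 (half of 255) stretches the range to 0..2, and subtracting 1 centers it on zero. The inverse comes in handy if you want to display a preprocessed image.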
Now, we’ll efficiently apply this function to all our datasets using .map(). Then, we’ll shuffle and batch the data to prepare it for training.
# Apply the formatting function to all datasets
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)
# Shuffle and batch the data
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
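Why does shuffle() need a buffer size? The sketch below is a simplified pure-Python model of buffered shuffling followed by batching. It is an illustration of the idea under my own assumptions, not TensorFlow’s actual implementation, and buffered_shuffle and batch are hypothetical names.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in a locally shuffled order using a fixed-size buffer.

    buffer_size must be at least 1.
    """
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and reuse its slot.
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(items, batch_size):
    """Group items into lists of batch_size; the last batch may be smaller."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

shuffled = list(buffered_shuffle(range(10), buffer_size=3))
batches = list(batch(shuffled, batch_size=4))
print(batches)
```

In this model an element can jump ahead by at most buffer_size - 1 positions, so a small buffer only mixes neighbors. A buffer closer to the dataset size gives a more thorough shuffle at the cost of more memory.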
Step 3: Choosing a Pre-Trained Model 🧠
Here comes the magic of transfer learning. We’re going to use MobileNet V2, a powerful model developed by Google and pre-trained on the massive ImageNet dataset (1.4 million images, 1000 classes). This model already knows how to recognize a vast array of features, like edges, textures, and shapes.
We’ll download the model without its original classification layer (include_top=False), because we only need its powerful feature-extraction base.
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
# Create the base model from the pre-trained MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Step 4: Freezing the Expert Brain
The weights of MobileNet V2 are already highly optimized. We don’t want to mess them up during our initial training. So, we’ll freeze the base model. This tells Keras not to update the weights of the MobileNet V2 layers, preserving all its hard-earned knowledge.
base_model.trainable = False
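Conceptually, freezing just means the optimizer skips those weights when it applies updates. Here’s a toy sketch of that idea with two made-up parameters. It is not how Keras implements it internally, and sgd_step and the parameter names are hypothetical.

```python
def sgd_step(params, grads, trainable, learning_rate):
    """Apply one gradient-descent step, skipping parameters marked as frozen."""
    updated = {}
    for name, value in params.items():
        if trainable[name]:
            updated[name] = value - learning_rate * grads[name]
        else:
            # Frozen: the gradient is ignored and the value is kept as-is.
            updated[name] = value
    return updated

params = {"base_weight": 0.5, "head_weight": 0.1}
grads = {"base_weight": 2.0, "head_weight": 1.0}
trainable = {"base_weight": False, "head_weight": True}

new_params = sgd_step(params, grads, trainable, learning_rate=0.1)
print(new_params)
```

Only head_weight moves; base_weight keeps its pre-trained value no matter how large its gradient is. In the real model, calling model.summary() after building it shows how many parameters are trainable and how many are not.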
Step 5: Adding Our Custom Classifier
Now that we have our feature extractor, we need to add our own classifier on top.
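- Pooling Layer: We’ll use a GlobalAveragePooling2D layer to average each feature channel over its spatial positions. For 160x160 inputs, the base model outputs a 5x5x1280 feature map, so pooling condenses it into a single 1280-element vector, far fewer values than the 32,000 you’d get by flattening.
- Output Layer: Since we only have two classes (cat or dog), we just need a single Dense neuron for our final prediction. It has no activation function, so it outputs a raw score (a logit): positive values predict class 1, and negative values predict class 0.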
Let’s stack these layers together using a Sequential model.
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])
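To see what the pooling layer does to a feature map, here’s a tiny pure-Python sketch on a 2x2 map with 2 channels. global_average_pool is a hypothetical stand-in for the Keras layer, written with nested lists so it runs without TensorFlow.

```python
def global_average_pool(feature_map):
    """Average a height x width x channels nested list over its spatial positions."""
    # Collapse the two spatial axes into one list of per-position channel vectors.
    positions = [pixel for row in feature_map for pixel in row]
    count = len(positions)
    # zip(*positions) groups the values by channel; average each group.
    return [sum(channel) / count for channel in zip(*positions)]

# A 2x2 feature map with 2 channels per position.
feature_map = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0]],
]

pooled = global_average_pool(feature_map)
flattened = [value for row in feature_map for pixel in row for value in pixel]
print("pooled length:", len(pooled), "flattened length:", len(flattened))
```

Pooling keeps one number per channel regardless of the map’s height and width, while flattening keeps every value. That gap is what keeps the Dense layer on top so small.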
Step 6: Training Our Souped-Up Model
It’s time to train! We’ll compile the model with a low learning rate of 0.0001. Because the base is frozen, training only updates the 1,281 parameters of our new classifier (1,280 weights plus one bias), and small steps keep those updates stable. We also pass from_logits=True to the loss, since our Dense layer outputs a raw score rather than a probability.
# Use a small learning rate
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
# Train the model
history = model.fit(train_batches,
                    epochs=3,
                    validation_data=validation_batches)
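Since the model outputs a raw score, it helps to see how that score turns into a probability, a loss, and a prediction. The sketch below works through the math for single examples in plain Python. The function names are hypothetical, and reading class 1 as "dog" is an assumption about the dataset’s label order that you can check in metadata.

```python
import math

def sigmoid(logit):
    """Convert a raw score (logit) into a probability for class 1."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    # For negative logits, this form avoids overflow in exp().
    e = math.exp(logit)
    return e / (1.0 + e)

def bce_from_logit(logit, label):
    """Binary cross-entropy for one example, computed directly from the logit."""
    # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def predict_class(logit):
    """Threshold at logit 0, which corresponds to a probability of 0.5."""
    return 1 if logit > 0 else 0

# A confident correct score, a confident wrong score, and an undecided one.
for logit, label in [(2.0, 1), (-2.0, 1), (0.0, 0)]:
    print(f"logit={logit:+.1f} label={label} "
          f"p(class 1)={sigmoid(logit):.3f} "
          f"loss={bce_from_logit(logit, label):.3f} "
          f"predicted={predict_class(logit)}")
```

A confident wrong score is penalized much more heavily than a confident right one, which is what pushes the classifier toward the correct side of zero during training.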
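After just a few epochs, you’ll see a surprisingly high validation accuracy, often over 95%, though exact numbers vary from run to run. That’s the power of transfer learning. To confirm the result on images the model has never seen, run model.evaluate(test_batches) on the held-out test set.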
Finally, you can save your trained model and reload it anytime without retraining. The .h5 extension selects Keras’s HDF5 format; recent Keras versions also offer a native .keras format.
# Save the model for later use
model.save("dogs_vs_cats.h5")
And there you have it! You’ve successfully built a high-performance image classifier by leveraging a world-class pre-trained model. Happy coding!