Build Multi-Class Image Classifier with Transfer Learning: TensorFlow Keras Tutorial for Beginners

Learn to build multi-class image classifiers using transfer learning with TensorFlow & Keras. Complete guide with code examples, data preprocessing & model optimization.

Lately, I’ve noticed how image recognition has become part of our daily lives - from sorting vacation photos to medical diagnostics. This got me thinking: how can developers efficiently build accurate classifiers without massive datasets or computational power? Transfer learning provides the answer, and today I’ll show you how to implement it using TensorFlow and Keras.

We’ll create a flower classification system, but these techniques apply to any image recognition task. Why start from scratch when we can build on existing knowledge? Pre-trained models offer powerful feature extraction capabilities we can adapt to new problems with minimal resources. Let’s see how this works in practice.

First, we set up our environment. Make sure you have these dependencies installed:

pip install tensorflow matplotlib numpy pillow scikit-learn seaborn

Then import essential libraries:

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from tensorflow.keras import layers

# Verify environment
print(f"TensorFlow: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

For our dataset, we’ll use the TensorFlow flower_photos collection - 3,670 images across 5 flower classes (daisy, dandelion, roses, sunflowers, tulips). Here’s how we load it and split it into training and validation sets:

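dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
# If no class folders are found here, check whether your Keras version
# extracted the archive into a nested flower_photos folder and adjust data_dir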

# Create datasets
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=(224, 224),
    batch_size=32
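)

# Folder names in alphabetical order; used later for the output layer and plots
class_names = train_ds.class_names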

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=42,
    image_size=(224, 224),
    batch_size=32
)
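To actually examine the dataset, one option is to count the files in each class folder. This is a minimal sketch that assumes the usual one-subfolder-per-class layout; the helper name count_images_per_class and the extension list are my own choices, not part of any library:

```python
import pathlib
from collections import Counter

def count_images_per_class(root, extensions=(".jpg", ".jpeg", ".png")):
    """Count image files in each immediate subfolder of `root`."""
    root = pathlib.Path(root)
    counts = Counter()
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in extensions
        )
    return dict(counts)

# Usage, assuming data_dir points at the extracted flower_photos folder:
# for name, n in count_images_per_class(data_dir).items():
#     print(f"{name:12s} {n}")
```

Files that sit directly in the root folder (a license file, for example) are ignored because only subfolders are scanned.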

Notice the class imbalance? Dandelions (898 images) outnumber daisies (633) by roughly 1.4 to 1. This matters because models can favor frequent classes. Data augmentation does not rebalance the classes by itself, but it artificially expands the training set through random transformations, which helps the model generalize from a dataset this small:

data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.1)
])

# Apply to dataset
augmented_train = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y)
)
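Augmentation adds variety, but it leaves the class proportions unchanged. If the imbalance turns out to hurt the rarer classes, one common option is to pass class weights to training. The sketch below uses the familiar "balanced" heuristic (total / (n_classes * count)); the helper name is hypothetical, and the counts are the whole-archive folder counts as I understand them, so recount on your own copy (ideally on the training split only):

```python
def balanced_class_weights(counts):
    """Weight each class by total / (n_classes * count), so rare classes count more."""
    total = sum(counts)
    n_classes = len(counts)
    return {i: total / (n_classes * c) for i, c in enumerate(counts)}

# Assumed per-class counts in alphabetical folder order:
# daisy, dandelion, roses, sunflowers, tulips
example_counts = [633, 898, 641, 699, 799]
class_weight = balanced_class_weights(example_counts)

for index, weight in class_weight.items():
    print(index, round(weight, 3))

# Hypothetical usage with the training call in this tutorial:
# model.fit(augmented_train, validation_data=val_ds, epochs=10,
#           class_weight=class_weight)
```

The keys are integer class indices, which is the form the class_weight argument of model.fit expects.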

Now comes the transfer learning magic. We’ll use MobileNetV2 - a lightweight model pre-trained on ImageNet - as our foundation:

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights="imagenet"
)

# Freeze base layers
base_model.trainable = False

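# Build classification head
model = tf.keras.Sequential([
    tf.keras.Input(shape=(224, 224, 3)),
    # MobileNetV2 expects pixels in [-1, 1]; the dataset yields [0, 255]
    layers.Rescaling(1.0 / 127.5, offset=-1),
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.2),
    # One logit per class (no softmax; the loss uses from_logits=True)
    layers.Dense(len(class_names))
])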

Why freeze the base model initially? This preserves the learned features while we train just the new classification layers. Later, we’ll selectively unfreeze layers for fine-tuning. But first, let’s train the top layers:

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

history = model.fit(
    augmented_train,
    validation_data=val_ds,
    epochs=10
)
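The final Dense layer has no activation, so the model outputs raw scores (logits), and from_logits=True tells the loss to apply softmax internally. Here is a small NumPy sketch of that computation, meant only to illustrate the idea; the function names and the example numbers are made up:

```python
import numpy as np

def softmax(logits):
    """Convert raw scores to probabilities, row by row."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def sparse_ce_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and raw logits."""
    probs = softmax(logits)
    picked = probs[np.arange(len(labels)), labels]  # probability of the true class
    return float(-np.log(picked).mean())

# Two made-up images, five classes each
logits = np.array([[2.0, 0.5, 0.1, 0.0, -1.0],
                   [0.0, 0.0, 3.0, 0.0, 0.0]])
labels = np.array([0, 2])

probs = softmax(logits)
loss = sparse_ce_from_logits(labels, logits)
print(probs.round(3))
print(loss)
```

This is also why predictions from this model need a softmax before they can be read as probabilities.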

After initial training, we unfreeze the top layers of our base model for fine-tuning:

base_model.trainable = True
# Keep only the last 20 layers trainable by re-freezing everything before them
for layer in base_model.layers[:-20]:
    layer.trainable = False

# Recompile with lower learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

# Train again
history_fine = model.fit(
    augmented_train,
    validation_data=val_ds,
    epochs=5
)
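To double-check which layers a slice like layers[:-20] leaves trainable, it can help to try the same pattern on a tiny model first. The sketch below uses a three-layer stand-in (toy_base and its layer names are invented for illustration) so it runs without downloading any weights:

```python
import tensorflow as tf
from tensorflow.keras import layers

# A tiny stand-in for base_model so the example runs without downloading weights
toy_base = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    layers.Dense(4, name="block_1"),
    layers.Dense(4, name="block_2"),
    layers.Dense(4, name="block_3"),
])

toy_base.trainable = True            # unfreeze everything first
for layer in toy_base.layers[:-1]:   # then re-freeze all but the last layer
    layer.trainable = False

trainable_layers = [layer.name for layer in toy_base.layers if layer.trainable]
print(trainable_layers)
print(len(toy_base.trainable_weights), "trainable weight tensors")
```

Running the same list comprehension on base_model after the freezing loop shows exactly which MobileNetV2 layers will be updated during fine-tuning.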

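Notice how we reduced the learning rate? Small updates lower the risk of overwriting the pre-trained features while the unfrozen layers adapt. How much improvement did we gain? Fine-tuning often adds a few percentage points of validation accuracy on this dataset, though the gain varies from run to run, so compare history_fine against history rather than assuming a fixed number. To evaluate properly, we need more than just accuracy scores: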

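# Collect labels and predictions in one pass: val_ds is shuffled by default,
# so iterating it twice would not keep labels and predictions aligned
y_true, y_pred = [], []
for images, labels in val_ds:
    y_pred.append(np.argmax(model.predict_on_batch(images), axis=1))
    y_true.append(labels.numpy())
y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)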

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

This reveals where our model confuses similar flowers - perhaps mixing daisies and sunflowers. Such insights guide improvement efforts better than any single metric.
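The heatmap is easier to act on once it is reduced to per-class precision and recall. Here is a minimal sketch, assuming rows are true classes and columns are predicted classes (the layout sklearn’s confusion_matrix returns); the helper name and the toy matrix are made up and are not results from this model:

```python
import numpy as np

def per_class_metrics(cm):
    """Return (precision, recall) arrays from a confusion matrix.

    Rows are true classes and columns are predicted classes.
    """
    cm = np.asarray(cm, dtype=float)
    true_pos = np.diag(cm)
    predicted_totals = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual_totals = cm.sum(axis=1)     # row sums: how often each class actually occurred
    precision = np.divide(true_pos, predicted_totals,
                          out=np.zeros_like(true_pos), where=predicted_totals > 0)
    recall = np.divide(true_pos, actual_totals,
                       out=np.zeros_like(true_pos), where=actual_totals > 0)
    return precision, recall

# Made-up 3-class matrix purely for illustration
toy_cm = [[8, 2, 0],
          [1, 9, 0],
          [0, 1, 9]]
precision, recall = per_class_metrics(toy_cm)
print("precision:", precision.round(3))
print("recall:   ", recall.round(3))
```

A class with low recall is being missed; a class with low precision is absorbing images that belong elsewhere.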

Finally, we save our trained model for deployment:

model.save("flower_classifier.keras")
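Once the file is reloaded (for example with tf.keras.models.load_model), the model still returns logits, so a deployment wrapper usually converts them into ranked labels. Below is a NumPy-only sketch of that last step; top_k_labels is a hypothetical helper and the logits are invented values standing in for one row of model.predict output:

```python
import numpy as np

def top_k_labels(logits, class_names, k=3):
    """Return the k most likely (label, probability) pairs for one image."""
    logits = np.asarray(logits, dtype=float)
    exp = np.exp(logits - logits.max())   # softmax, shifted for stability
    probs = exp / exp.sum()
    order = np.argsort(probs)[::-1][:k]   # indices sorted by descending probability
    return [(class_names[i], float(probs[i])) for i in order]

# Class folders in alphabetical order, as image_dataset_from_directory lists them
names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

# Invented logits for a single image
ranked = top_k_labels([0.2, 0.1, 1.5, 0.3, 2.4], names, k=2)
for label, prob in ranked:
    print(f"{label}: {prob:.2f}")
```

Keeping the class name list alongside the saved model matters, because the .keras file stores the weights and architecture but not the folder names.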

Transfer learning democratizes computer vision. With just a few thousand images and basic hardware, we’ve built a working classifier. What problems could you solve with these techniques? Medical imaging? Quality control? The possibilities are endless.

Found this practical? Share it with others starting their ML journey! Have questions or improvements? Let’s discuss in the comments - your insights make our community stronger.
