Transfer Learning Image Classification: Build Multi-Class Classifiers with PyTorch ResNet Complete Tutorial

Learn to build powerful multi-class image classifiers using PyTorch transfer learning and ResNet. Complete guide with code examples, data augmentation tips, and model optimization techniques.

I’ve always been fascinated by how computers learn to recognize images, but training models from scratch felt like reinventing the wheel. That’s why I fell in love with transfer learning: it lets you stand on the shoulders of giants. Today, I want to show you how to build a capable image classifier without starting from zero. Why now? Because I’ve seen too many developers struggle with limited data, and transfer learning addresses exactly that problem.

Have you ever wondered how to classify dozens or even hundreds of image categories with only a fraction of the data that training from scratch would need? Transfer learning makes this possible by reusing a model already trained on a massive dataset. We’ll use ResNet, pre-trained on ImageNet (about 1.2 million training images in 1,000 classes), and adapt it to your own categories.

Let me walk you through the process. First, we set up our environment. You’ll need PyTorch and torchvision. Here’s how to install them:

pip install torch torchvision matplotlib pillow

Now, let’s import the necessary libraries. I always start with these basics:

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

Data preparation is crucial. I organize images in folders named by class, like “cats”, “dogs”, etc. This makes loading straightforward. Here’s a simple dataset class I often use:

class ImageDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = sorted([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}
        
        for class_name in self.class_names:
            class_dir = os.path.join(data_dir, class_name)
            for img_name in os.listdir(class_dir):
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    img_path = os.path.join(class_dir, img_name)
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[class_name])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

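What if your images vary in size, framing, or lighting? Data augmentation helps: random crops, flips, and brightness/contrast changes show the model plausible variations of each training image, which makes it more robust. These are the training transforms I use:

train_transforms = transforms.Compose([
    transforms.Resize(256),                                # shorter side -> 256 px
    transforms.RandomCrop(224),                            # random 224x224 patch (ResNet's input size)
    transforms.RandomHorizontalFlip(),                     # flip each image with 50% probability
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # mild lighting variation
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
])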

Now for the fun part: loading a pre-trained ResNet. I usually start with ResNet18 for speed, but larger variants such as ResNet50 often give better accuracy at a higher compute cost. Here’s how I replace the final layer to match my classes:

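# torchvision >= 0.13: `pretrained=True` is deprecated in favor of the `weights` argument
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_classes = 10  # Change this to your number of classes
# Replace the 1000-class ImageNet head with a new layer sized for your classes
model.fc = nn.Linear(model.fc.in_features, num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)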

Freezing the pre-trained layers saves training time and, when data is limited, lowers the risk of overfitting. I freeze the whole backbone and train only the new classifier:

for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

Training needs a loss function and an optimizer. I prefer cross-entropy loss with the Adam optimizer, and I pass the optimizer only the new head’s parameters, since everything else is frozen:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
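`nn.CrossEntropyLoss` applies log-softmax internally, so the model should output raw scores (logits) and the targets should be integer class indices, not one-hot vectors. Adding your own softmax layer before this loss is a common mistake. The small check below, on made-up logits, shows the loss is equivalent to `log_softmax` followed by `nll_loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # 4 samples, 3 classes: raw model outputs
targets = torch.tensor([0, 2, 1, 2])  # integer class indices (dtype int64)

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(ce.item(), manual.item())  # the two values agree
```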

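During training, I monitor both loss and accuracy. Here’s a snippet from my training loop:

num_epochs = 10  # adjust for your dataset

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch doesn't skew the average
        running_loss += loss.item() * inputs.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    print(f'Epoch {epoch+1}, Loss: {running_loss/total:.4f}, Train Acc: {100 * correct / total:.2f}%')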

After training, evaluate on a validation set. I check accuracy like this:

correct = 0
total = 0
model.eval()  # evaluation mode: BatchNorm uses its stored running statistics
with torch.no_grad():  # no gradients needed, which saves memory and time
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the highest-scoring class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f}%')
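For multi-class problems, overall accuracy can hide classes the model rarely gets right, especially when classes are imbalanced. A confusion matrix shows where the errors go. This sketch builds one with plain PyTorch; the function names are illustrative, and `collect_predictions` assumes a loader that yields `(inputs, labels)` batches like `val_loader`.

```python
import torch


@torch.no_grad()
def collect_predictions(model, loader, device):
    """Run the model over a loader and return (predictions, labels) as CPU tensors."""
    model.eval()
    all_preds, all_labels = [], []
    for inputs, labels in loader:
        outputs = model(inputs.to(device))
        all_preds.append(outputs.argmax(dim=1).cpu())
        all_labels.append(labels.cpu())
    return torch.cat(all_preds), torch.cat(all_labels)


def confusion_matrix(preds, labels, num_classes):
    """Rows are true classes, columns are predicted classes."""
    # Encode each (true, predicted) pair as one index, then count occurrences
    idx = labels * num_classes + preds
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def per_class_accuracy(cm):
    """Fraction of each true class predicted correctly (NaN for classes with no samples)."""
    return cm.diag().float() / cm.sum(dim=1).float()
```

For example, with predictions `[0, 1, 1, 2]` and true labels `[0, 1, 2, 2]` over three classes, `per_class_accuracy` gives `[1.0, 1.0, 0.5]`: the model misses one of the two class-2 images, even though overall accuracy is 75%.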

What if your model isn’t accurate enough? Try unfreezing some of the later layers and fine-tuning them with a smaller learning rate than the new head, or tune the learning rate itself. I often add a learning rate scheduler that lowers the rate as training progresses.

Finally, save the trained weights for future use:

torch.save(model.state_dict(), 'image_classifier.pth')

I hope this guide helps you build your own image classifier efficiently. Transfer learning has saved me countless hours, and I’m excited to see what you create. If you found this useful, please like, share, and comment with your experiences or questions. Let’s learn together.
