Build Custom ResNet from Scratch with PyTorch: Complete Guide to Skip Connections and Image Classification

Learn to build custom ResNet from scratch with PyTorch. Master skip connections, solve vanishing gradients, and implement deep image classification networks with hands-on code examples.

The other day, I was trying to train a very deep neural network for an image recognition task. No matter what I did, the accuracy wouldn’t improve past a certain point. My deeper models were performing worse than shallow ones, which was incredibly frustrating. A classic culprit is the vanishing gradient problem, where the signal used to update the network’s weights shrinks toward zero as it travels back through dozens of layers. The ResNet paper calls the broader symptom, deeper plain networks ending up with higher training error than shallower ones, the degradation problem. It felt like hitting a wall. That’s when I decided to stop just using pre-built models and really understand the architecture that broke this barrier: the Residual Network, or ResNet.

Have you ever wondered how networks with over 100 layers are even possible to train? The answer lies in a beautifully simple concept called a skip connection. Instead of asking a stack of layers to learn a complete transformation H(x), we ask it to learn a residual function F(x) = H(x) - x, the difference between the desired output and the block’s input. This is achieved by adding the input of a block directly to its output. The addition creates a direct pathway for gradients to flow backwards, which keeps them from vanishing as quickly as they would in a plain stack.
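To make the gradient claim concrete, here is a toy sketch rather than a real ResNet. It compares the gradient that reaches the first layer of a 30-layer fully connected stack with and without skip connections. The names PlainStack, ResidualStack, and first_layer_grad_norm, the width and depth, and the sigmoid activation (picked to exaggerate the effect) are all illustrative assumptions, not part of any library or of the network built in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class PlainStack(nn.Module):
    """A plain chain of Linear + sigmoid layers with no skip connections."""

    def __init__(self, width=32, depth=30):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])

    def forward(self, x):
        for layer in self.layers:
            x = torch.sigmoid(layer(x))
        return x


class ResidualStack(PlainStack):
    """Same layers, but each one adds its input back: x + F(x)."""

    def forward(self, x):
        for layer in self.layers:
            x = x + torch.sigmoid(layer(x))
        return x


def first_layer_grad_norm(model):
    """Backpropagate a dummy loss and return the gradient norm at the first layer."""
    x = torch.randn(16, 32)
    model(x).sum().backward()
    return model.layers[0].weight.grad.norm().item()


plain_norm = first_layer_grad_norm(PlainStack())
residual_norm = first_layer_grad_norm(ResidualStack())
print(f"plain first-layer grad norm:    {plain_norm:.3e}")
print(f"residual first-layer grad norm: {residual_norm:.3e}")
```

In the plain stack the gradient is multiplied by a small factor at every layer, so almost nothing reaches the first layer. In the residual stack the identity path carries the gradient back regardless of what each layer does.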

Let’s look at the core building block. The key equation is: output = F(x) + x. Here, x is the input, F(x) is what a few convolutional layers learn, and the sum is the output. If F(x) learns nothing useful, it can theoretically push its weights toward zero, and the block just outputs x. This means the network can easily learn an identity function, which is much harder in a standard sequential chain.
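The identity argument can be checked directly. The sketch below uses a deliberately stripped-down block (TinyResidual is a hypothetical name, with a single convolution and no batch norm or activation) and sets the weights of F to zero, so the block should pass its input through unchanged.

```python
import torch
import torch.nn as nn


class TinyResidual(nn.Module):
    """Minimal residual block: output = F(x) + x, where F is one 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv(x) + x


block = TinyResidual(channels=4)
nn.init.zeros_(block.conv.weight)  # F(x) is now zero for every input

x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    y = block(x)

# With F(x) == 0 the block reduces to the identity mapping.
print(torch.allclose(y, x))
```

A plain convolutional layer would have to learn a very specific kernel to reproduce its input. The residual block only has to drive its weights to zero.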

Here’s how we code a basic Residual Block in PyTorch.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # First convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # This handles cases where the input needs to be modified to match the output dimensions
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x  # This is the 'skip'
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(identity)  # The crucial addition
        out = self.relu(out)
        return out

Notice the shortcut path. When we change the number of channels or the spatial size (using stride=2), we can’t just add x to out directly, because the two tensors no longer have the same shape. We use a 1x1 convolution with the same stride, followed by batch normalization, to project the identity to the correct shape so the element-wise addition is valid.
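A quick shape check shows why the projection is needed. This sketch uses bare convolutions with the same kernel sizes, strides, and padding as the block above. The variable names and the 16-to-32 channel example are my own choices for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 32, 32)  # one 16-channel 32x32 feature map

# Main path: 3x3 convolution that doubles the channels and halves the resolution.
main = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)
# Shortcut path: 1x1 convolution with the same stride and output channels.
proj = nn.Conv2d(16, 32, kernel_size=1, stride=2, bias=False)

main_out = main(x)
proj_out = proj(x)
print("main path:", tuple(main_out.shape))
print("projected shortcut:", tuple(proj_out.shape))

# Adding the raw input fails because the shapes cannot be broadcast together.
try:
    _ = main_out + x
    direct_add_works = True
except RuntimeError:
    direct_add_works = False
print("can add raw input directly:", direct_add_works)

# Adding the projected input works.
summed = main_out + proj_out
```

Both paths end up with 32 channels at 16x16 resolution, so the sum is well defined.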

Building the full network involves stacking these blocks into groups, gradually increasing the number of filters and reducing the spatial dimensions. We start with an initial convolution and pooling, then move through the residual layers.
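Here is one way the stacking could look, as a sketch under a few assumptions. It follows the common ResNet-18 layout (a 7x7 stem convolution, max pooling, then four stages of two blocks with 64, 128, 256, and 512 filters). Block is a condensed restatement of the residual block so the snippet runs on its own, and make_stage and ResNetSketch are hypothetical names.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Condensed two-convolution residual block with an optional projection shortcut."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        return torch.relu(self.body(x) + self.shortcut(x))


def make_stage(in_ch, out_ch, num_blocks, stride):
    """Only the first block of a stage changes channels or resolution."""
    blocks = [Block(in_ch, out_ch, stride)]
    blocks += [Block(out_ch, out_ch) for _ in range(num_blocks - 1)]
    return nn.Sequential(*blocks)


class ResNetSketch(nn.Module):
    def __init__(self, num_classes=10, blocks_per_stage=(2, 2, 2, 2),
                 widths=(64, 128, 256, 512)):
        super().__init__()
        # Stem: initial convolution and pooling.
        self.stem = nn.Sequential(
            nn.Conv2d(3, widths[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Stages: more filters and half the resolution at each stage after the first.
        stages = []
        in_ch = widths[0]
        for i, (width, num_blocks) in enumerate(zip(widths, blocks_per_stage)):
            stride = 1 if i == 0 else 2
            stages.append(make_stage(in_ch, width, num_blocks, stride))
            in_ch = width
        self.stages = nn.Sequential(*stages)
        # Head: global average pooling followed by a linear classifier.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(in_ch, num_classes)
        )

    def forward(self, x):
        return self.head(self.stages(self.stem(x)))


model = ResNetSketch(num_classes=10)
logits = model(torch.randn(2, 3, 64, 64))
print(tuple(logits.shape))
```

Changing blocks_per_stage is enough to make the network deeper. For small inputs such as 32x32 images, a lighter stem (a single 3x3 convolution without pooling) is a common alternative.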

What does training such a model actually feel like? You watch the loss drop more steadily than in a plain network of similar depth. Those skip connections are like express lanes for your gradients, ensuring every layer gets a meaningful update. To see it in action, let’s run a single training step on a dummy batch.

# Factory that defines and returns a small ResNet built from BasicBlock
def make_small_resnet():
    class SmallResNet(nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
            self.bn1 = nn.BatchNorm2d(16)
            self.relu = nn.ReLU()
            # Layer 1: Two basic blocks
            self.layer1 = nn.Sequential(BasicBlock(16, 16), BasicBlock(16, 16))
            # Layer 2: Downsample
            self.layer2 = nn.Sequential(BasicBlock(16, 32, stride=2), BasicBlock(32, 32))
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(32, num_classes)
        def forward(self, x):
            x = self.relu(self.bn1(self.conv1(x)))
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x
    return SmallResNet()

model = make_small_resnet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Example training step on dummy data
model.train()  # BatchNorm uses batch statistics in training mode
data = torch.randn(8, 3, 32, 32)  # 8 images, 3 channels, 32x32 pixels
target = torch.randint(0, 10, (8,))  # random class labels in [0, 10)
optimizer.zero_grad()  # clear gradients left over from any previous step
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f"Loss on this batch (before the update): {loss.item():.4f}")

This small example shows the mechanics. For real tasks like classifying CIFAR-10 or ImageNet, you would build the full architecture with many more layers, add data augmentation, and use a learning rate schedule. The principle remains identical.
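As a rough sketch of those two extras, the snippet below pairs a simple augmentation (random crop after zero padding, plus a random horizontal flip) with a cosine learning rate schedule. The augment helper is hypothetical and written in plain PyTorch, the linear model is only a stand-in for a real network, and every hyperparameter (pad of 4, SGD with lr 0.1, 5 epochs) is an assumed starting point rather than a tuned recipe.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def augment(images, pad=4):
    """Randomly crop each zero-padded image back to its original size, then maybe flip it."""
    n, _, h, w = images.shape
    padded = F.pad(images, (pad, pad, pad, pad))
    out = torch.empty_like(images)
    for i in range(n):
        top = torch.randint(0, 2 * pad + 1, (1,)).item()
        left = torch.randint(0, 2 * pad + 1, (1,)).item()
        crop = padded[i, :, top:top + h, left:left + w]
        if torch.rand(1).item() < 0.5:
            crop = crop.flip(-1)  # horizontal flip
        out[i] = crop
    return out


model = nn.Linear(3 * 32 * 32, 10)  # stand-in for a real ResNet
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

lrs = []
for epoch in range(epochs):
    # One dummy batch per epoch; a real loop would iterate over a DataLoader here.
    images = augment(torch.randn(8, 3, 32, 32))
    targets = torch.randint(0, 10, (8,))

    optimizer.zero_grad()
    loss = F.cross_entropy(model(images.flatten(1)), targets)
    loss.backward()
    optimizer.step()

    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()  # decay the learning rate once per epoch

print([round(lr, 4) for lr in lrs])
```

The learning rate starts at its base value and decays smoothly along a cosine curve over the scheduled epochs.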

Why does this approach work so well for image classification? It gives the network the flexibility to be as complex as it needs to be, without getting stuck. The skip connections act as a stabilizing force during training. They are the reason we can build models that are both deep and accurate, from recognizing objects in photos to analyzing medical scans.

My journey from hitting that accuracy wall to building a working ResNet taught me that sometimes the most powerful solutions are conceptually simple. The genius is in the design, not just the complexity. By implementing it yourself, you move from a user of tools to a true builder of AI systems.

I hope this walkthrough helps you add a powerful tool to your own projects. If you found this explanation useful, please share it with others who might be facing the same training hurdles. I’d love to hear about your experiments with custom architectures in the comments below—what problems are you trying to solve with deep learning?
