Build Custom ResNet from Scratch with PyTorch: Complete Guide to Skip Connections and Image Classification

Learn to build custom ResNet from scratch with PyTorch. Master skip connections, solve vanishing gradients, and implement deep image classification networks with hands-on code examples.

The other day, I was trying to train a very deep neural network for an image recognition task. No matter what I did, the accuracy just wouldn’t improve past a certain point. My deeper models were performing worse than shallow ones, which was incredibly frustrating. This is the classic vanishing gradient problem, where the signal used to update the network’s weights fades away as it travels back through dozens of layers. It felt like hitting a wall. That’s when I decided to stop relying on pre-built models and really understand the architecture that broke this barrier: the Residual Network, or ResNet.

Have you ever wondered how networks with over 100 layers are even possible to train? The answer lies in a beautifully simple concept called a skip connection. Instead of a layer trying to learn a complete transformation, it learns a residual function—the difference between its input and the desired output. This is achieved by adding the input of a block directly to its output. This simple act creates a direct pathway for gradients to flow backwards, preventing them from vanishing.

Let’s look at the core building block. The key equation is output = F(x) + x, where x is the input, F(x) is the residual learned by a few convolutional layers, and the sum is the block’s output. If those layers have nothing useful to learn, they can push their weights toward zero and the block simply passes x through unchanged. In other words, the network can easily fall back to an identity mapping, which is much harder to achieve with a standard sequential chain of layers.
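
To make this concrete, here is a tiny sanity check of my own (a hedged sketch; the tensor shape and the zeroed-out residual are purely illustrative):

import torch

x = torch.randn(4, 8, requires_grad=True)
f_x = torch.zeros_like(x)          # pretend the learned residual F(x) has collapsed to zero
output = f_x + x                   # the block reduces to the identity mapping
output.sum().backward()
print(torch.allclose(x.grad, torch.ones_like(x)))  # True: gradients flow straight through the skip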

Here’s how we code a basic Residual Block in PyTorch.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # First convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # This handles cases where the input needs to be modified to match the output dimensions
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x  # This is the 'skip'
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(identity)  # The crucial addition
        out = self.relu(out)
        return out

Notice the shortcut path. When we change the number of channels or the spatial size (using stride=2), we can’t just add x to out directly. We use a 1x1 convolution to project the identity to the correct shape. This maintains the integrity of the skip connection.
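
As a quick shape check, here is a hedged sketch (the input size is arbitrary and only for illustration) showing the projected shortcut in action:

x = torch.randn(1, 16, 32, 32)        # 16 channels on a 32x32 feature map
block = BasicBlock(16, 32, stride=2)  # doubles the channels, halves the resolution
print(block(x).shape)                 # torch.Size([1, 32, 16, 16])
# Inside the block, the 1x1 shortcut projects x to [1, 32, 16, 16] as well,
# so the element-wise addition in forward() is valid.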

Building the full network involves stacking these blocks into groups, gradually increasing the number of filters and reducing the spatial dimensions. We start with an initial convolution and pooling, then move through the residual layers.
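
A small helper keeps that stacking tidy. This is only a sketch in the spirit of torchvision’s ResNet; the helper name and the stage widths below are illustrative, not a fixed recipe:

def make_layer(block, in_channels, out_channels, num_blocks, stride):
    # Only the first block in a stage changes the channel count or downsamples;
    # the remaining blocks keep the shape fixed.
    layers = [block(in_channels, out_channels, stride=stride)]
    for _ in range(num_blocks - 1):
        layers.append(block(out_channels, out_channels, stride=1))
    return nn.Sequential(*layers)

# A ResNet-18-style stack would then look roughly like:
# layer1 = make_layer(BasicBlock, 64, 64, num_blocks=2, stride=1)
# layer2 = make_layer(BasicBlock, 64, 128, num_blocks=2, stride=2)
# layer3 = make_layer(BasicBlock, 128, 256, num_blocks=2, stride=2)
# layer4 = make_layer(BasicBlock, 256, 512, num_blocks=2, stride=2)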

What does training such a model actually feel like? You watch the loss drop more steadily than in a plain network of similar depth. Those skip connections are like express lanes for your gradients, ensuring every layer gets a meaningful update. To see it in action, let’s set up a quick training loop on a sample batch.

# Instantiate a small ResNet
def make_small_resnet():
    class SmallResNet(nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
            self.bn1 = nn.BatchNorm2d(16)
            self.relu = nn.ReLU()
            # Layer 1: Two basic blocks
            self.layer1 = nn.Sequential(BasicBlock(16, 16), BasicBlock(16, 16))
            # Layer 2: Downsample
            self.layer2 = nn.Sequential(BasicBlock(16, 32, stride=2), BasicBlock(32, 32))
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(32, num_classes)
        def forward(self, x):
            x = self.relu(self.bn1(self.conv1(x)))
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x
    return SmallResNet()

model = make_small_resnet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Example training step on dummy data
data = torch.randn(8, 3, 32, 32)  # 8 images, 3 channels, 32x32 pixels
target = torch.randint(0, 10, (8,))
optimizer.zero_grad()  # clear any stale gradients before the backward pass
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f"Loss after one step: {loss.item():.4f}")

This small example shows the mechanics. For real tasks like classifying CIFAR-10 or ImageNet, you would build the full architecture with many more layers, add data augmentation, and use a learning rate schedule. The principle remains identical.
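
As a rough sketch of what that might look like for CIFAR-10 (the augmentations, normalization statistics, and schedule length below are common defaults, not tuned values):

import torchvision.transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Standard CIFAR-10 augmentation: padded random crops plus horizontal flips
train_tf = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = CIFAR10(root="./data", train=True, download=True, transform=train_tf)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)  # call scheduler.step() once per epoch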

Why does this approach work so well for image classification? It gives the network the flexibility to be as complex as it needs to be, without getting stuck. The skip connections act as a stabilizing force during training. They are the reason we can build models that are both deep and accurate, from recognizing objects in photos to analyzing medical scans.

My journey from hitting that accuracy wall to building a working ResNet taught me that sometimes the most powerful solutions are conceptually simple. The genius is in the design, not just the complexity. By implementing it yourself, you move from a user of tools to a true builder of AI systems.

I hope this walkthrough helps you add a powerful tool to your own projects. If you found this explanation useful, please share it with others who might be facing the same training hurdles. I’d love to hear about your experiments with custom architectures in the comments below—what problems are you trying to solve with deep learning?
