The other day, I tried to teach my computer to tell the difference between a daisy, a dandelion, and a tulip. I had a small collection of photos, far too few to train a modern image recognizer from scratch. The thought of needing millions of labeled images and weeks of computing time was daunting. That’s when I remembered a powerful shortcut: transfer learning. It’s the practice of taking a model already skilled at recognizing a thousand different things and repurposing it for your specific task. Today, I want to guide you through exactly how to do this with PyTorch, taking a model from an idea to a functioning application. Think of it as giving an expert artist a new, more specific subject to paint—they already know how to hold the brush.
Let’s begin by setting the stage. You’ll need PyTorch and Torchvision installed. I always start by setting a random seed for consistent results and checking if a GPU is available. This small step saves hours of debugging later.
import torch
import torch.nn as nn
from torchvision import models, transforms
# For reproducible results
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Now, for the data. A model is only as good as what it learns from. I organize my flower images into folders, one per class, and PyTorch’s ImageFolder reads that structure directly. But raw images aren’t ready yet; we need to resize them, convert them to tensors, and normalize their color values to match what the pre-trained model expects. We also apply data augmentation, randomly cropping and flipping the training images, so the model learns to recognize a flower no matter its framing. Why does this simple trick work so well? It artificially expands our dataset, forcing the model to learn the essential features of a tulip rather than the specific pose in one photo.
from torchvision import datasets
from torch.utils.data import DataLoader
# Define transformations
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# Load data (shuffle only the training set)
image_datasets = {
    x: datasets.ImageFolder(f'path/to/your/data/{x}', data_transforms[x])
    for x in ['train', 'val']
}
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=(x == 'train'))
    for x in ['train', 'val']
}
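Before moving on, I like to confirm that ImageFolder found the classes I expect; the class names come straight from the subfolder names. A quick check, just a sketch of how I peek at the datasets built above:
# Class names are taken from the subfolder names, in sorted order
class_names = image_datasets['train'].classes
print(f"Classes: {class_names}")
print(f"Training images: {len(image_datasets['train'])}, validation images: {len(image_datasets['val'])}")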
With our data pipelines ready, we build the model. We’ll use a ResNet-18, a model pre-trained on the massive ImageNet dataset. Its early layers have learned to detect basic shapes like edges and textures—skills that are useful for almost any vision task. We keep those layers frozen and only replace the final fully-connected layer with a new one that outputs predictions for our specific number of flower classes.
# Load the pre-trained model
model = models.resnet18(weights='DEFAULT')
# Freeze all the parameters in the network
for param in model.parameters():
    param.requires_grad = False
# Replace the final layer; its input size is model.fc.in_features
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5) # Let's say we have 5 flower classes
# Move the model to the GPU if available
model = model.to(device)
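If you want to convince yourself the freeze actually worked, a small check like this lists which parameters will still receive gradients; with the setup above, only the new head should appear:
# Sanity check: only the replacement layer should be trainable
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"Trainable parameters: {trainable}")  # expect ['fc.weight', 'fc.bias']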
Training is next. We need a loss function to measure mistakes and an optimizer to correct them. Since we only train the new final layer, we tell the optimizer to only update the parameters of model.fc. We loop through our data in batches, let the model make predictions, calculate the loss, and nudge the weights in the right direction. Watching the validation accuracy climb after each epoch is incredibly satisfying. What do you think happens if we unfreeze more layers later in training?
criterion = nn.CrossEntropyLoss()
# Only parameters of the final layer are being optimized
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
# A simple training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    for inputs, labels in dataloaders['train']:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # ... Validation phase would go here ...
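The validation phase I glossed over is nothing exotic. Here is a minimal sketch of a helper I would call at the end of each epoch, reusing the criterion, device, and dataloaders defined above; the evaluate name is my own, not a library function:
def evaluate(model, loader):
    # Run the model over a loader without updating weights; return average loss and accuracy
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total
# Inside the epoch loop, after the training phase:
# val_loss, val_acc = evaluate(model, dataloaders['val'])
# print(f"Epoch {epoch + 1}: val loss {val_loss:.4f}, accuracy {val_acc:.4f}")
As for that question about unfreezing: once the new head has settled, I sometimes thaw the last residual block and keep training at a much lower learning rate, so the pre-trained features shift only gently. A sketch of that fine-tuning step, assuming the ResNet-18 above:
# Fine-tuning sketch: unfreeze the last residual block and train it slowly
for param in model.layer4.parameters():
    param.requires_grad = True
optimizer = torch.optim.SGD([
    {'params': model.layer4.parameters(), 'lr': 1e-4},  # gentle updates to pre-trained features
    {'params': model.fc.parameters()},                   # the new head keeps the default rate below
], lr=1e-3, momentum=0.9)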
After training, we must evaluate. Accuracy is a good start, but I also look at a confusion matrix. This shows if the model is consistently mixing up two specific types of flowers, which points to a need for more distinctive training data for those classes. Finally, to use the model in an app, we save its learned weights and write a simple function to process new images.
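Here is a minimal sketch of that confusion-matrix check, tallying counts by hand on the validation set rather than pulling in an extra library; rows are the true classes and columns are the predictions:
# Build a confusion matrix over the validation set
num_classes = len(image_datasets['val'].classes)
confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
model.eval()
with torch.no_grad():
    for inputs, labels in dataloaders['val']:
        inputs, labels = inputs.to(device), labels.to(device)
        _, predicted = torch.max(model(inputs), 1)
        for t, p in zip(labels, predicted):
            confusion[t.item(), p.item()] += 1
print(confusion)
With that done, saving and inference look like this: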
from PIL import Image  # needed to open image files for inference
# Save the model state
torch.save(model.state_dict(), 'flower_classifier.pth')
# Function for inference
def predict_image(image_path, model, transform):
    model.eval()
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
    return predicted.item()
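To use the classifier later, rebuild the same architecture, load the saved weights, and hand predict_image the validation transform. A usage sketch; the image path here is a placeholder, and the class names are simply read back from the training folder:
# Rebuild the architecture, load the saved weights, and classify one new image
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 5)
model.load_state_dict(torch.load('flower_classifier.pth', map_location=device))
model = model.to(device)
class_names = image_datasets['train'].classes
idx = predict_image('path/to/some/flower.jpg', model, data_transforms['val'])
print(f"Predicted: {class_names[idx]}")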
This journey from a folder of images to a working classifier demystifies a core technique in modern AI. You don’t need a supercomputer or a vast dataset; you can stand on the shoulders of giants. I built this to sort my garden photos, but the same principles apply to medical imaging or quality control on a factory line. What problem could you solve by teaching a model to see?
If this walkthrough helped you see the path forward, please share it with someone else who might be starting their own project. I’d love to hear what you’re building—drop a comment below and let’s discuss.