Image Classification with Pre-trained ResNet18 & PyTorch

Master image classification using PyTorch & pre-trained ResNet18. Learn transfer learning for powerful AI vision with limited data. Hands-on guide!

Hands-on: Image Classification Using Pre-trained ResNet18 with PyTorch

Image classification is a cornerstone of computer vision. Instead of training a Convolutional Neural Network (CNN) from scratch, we can leverage the power of pre-trained models like ResNet18, which have been trained on massive datasets such as ImageNet. This technique, known as transfer learning, allows us to achieve excellent results on real-world problems with limited data and computational resources.

What You'll Learn

  • How to load a pre-trained ResNet18 model in PyTorch.
  • Essential image preprocessing steps for CNNs.
  • How to perform inference (prediction) on an image.
  • How to visualize classification results.

Prerequisites

Ensure you have the following Python libraries installed:

pip install torch torchvision matplotlib Pillow

Example: Classifying an Image with ResNet18

This example walks you through the process of loading a pre-trained ResNet18 model, preprocessing an image, performing inference, and displaying the prediction.

1. Import Required Libraries

import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import io  # used only by the optional download fallback in step 3
import urllib.request

2. Load a Pre-trained ResNet18 Model

We'll load the ResNet18 model pre-trained on ImageNet and set it to evaluation mode.

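# Load ResNet18 with ImageNet weights (torchvision 0.13+ weights API;
# older code used the now-deprecated pretrained=True)
weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=weights)

# Set the model to evaluation mode
# BatchNorm layers then use their stored running statistics instead of per-batch
# statistics (Dropout would also be disabled, though ResNet18 contains none)
model.eval()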
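To see what model.eval() actually changes, the sketch below uses a small toy network (hypothetical, not part of ResNet18) containing the two layer types that behave differently in training and evaluation mode. In training mode, Dropout draws a new random mask on every call and BatchNorm normalizes with the current batch's statistics. In evaluation mode both become deterministic, and a sample's output no longer depends on the other samples in its batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy network (hypothetical) with the two layer types affected by eval()
net = nn.Sequential(
    nn.Linear(8, 8),
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.Dropout(p=0.5),
)
x = torch.randn(4, 8)  # a batch of 4 samples with 8 features each

# Training mode: Dropout samples a fresh mask on each call, and BatchNorm
# normalizes with this batch's mean/variance (while updating its running stats)
net.train()
with torch.no_grad():
    train_out_1 = net(x)
    train_out_2 = net(x)

# Evaluation mode: Dropout passes inputs through, BatchNorm uses running stats
net.eval()
with torch.no_grad():
    eval_out_1 = net(x)
    eval_out_2 = net(x)
    first_alone = net(x[:1])  # the first sample again, but in a batch of 1

print("train calls identical:", torch.equal(train_out_1, train_out_2))
print("eval calls identical: ", torch.equal(eval_out_1, eval_out_2))
print("eval output independent of batch:",
      torch.allclose(eval_out_1[:1], first_alone, atol=1e-6))
```

The last check is why evaluation mode matters for single-image inference: in training mode a batch of one would be normalized with its own statistics, not the ones learned on ImageNet.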

3. Load and Preprocess the Image

torchvision's ImageNet models were trained on 224x224 crops whose pixel values were normalized with the per-channel mean and standard deviation of the ImageNet training set, so new images should be preprocessed the same way.

# Define the image transformations
# 1. Resize the shorter side of the image to 256 pixels, keeping the aspect ratio
# 2. Crop the center 224x224 pixels
# 3. Convert the image to a PyTorch Tensor
# 4. Normalize the tensor with ImageNet's mean and standard deviation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean for RGB channels
        std=[0.229, 0.224, 0.225]     # ImageNet standard deviation for RGB channels
    )
])

# --- Image Loading ---
# Replace 'your_image.jpg' with the actual path to your image file
image_path = 'your_image.jpg'
try:
    # Load the image with Pillow; convert('RGB') turns grayscale or RGBA images into 3-channel RGB
    image = Image.open(image_path).convert('RGB')
except FileNotFoundError:
    # Optional fallback: download an example image instead (uses io and urllib.request)
    # url_example = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/Golden_Retriever_puppy.jpg/256px-Golden_Retriever_puppy.jpg"
    # img_data = urllib.request.urlopen(url_example).read()
    # image = Image.open(io.BytesIO(img_data)).convert('RGB')
    # Without a fallback, stop with a clear error message
    raise SystemExit(f"Error: Image file not found at {image_path}")


# Apply the transformations and add a batch dimension
# The model expects input in the shape of (batch_size, channels, height, width)
input_tensor = transform(image).unsqueeze(0)

Note: Make sure to replace 'your_image.jpg' with the correct path to an image file on your system.

4. Perform Inference

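We feed the preprocessed image tensor into the model, which returns one raw score (logit) per ImageNet class. We then use .max(1) to find the index of the highest score along the class dimension; because softmax preserves ordering, this is also the class with the highest probability.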

# Perform inference
with torch.no_grad():  # Disable gradient calculation for inference
    outputs = model(input_tensor)
    # Get the index of the highest-scoring class (dimension 1 holds the class scores)
    _, predicted_index = outputs.max(1)
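Because the model returns logits rather than probabilities, a softmax is needed if you want confidence values, and torch.topk gives the several most likely classes instead of just one. This sketch uses a hand-written logits tensor over 5 hypothetical classes so it runs without the model; with ResNet18 you would pass the outputs tensor from the step above, which holds 1000 scores per image.

```python
import torch

# Hypothetical logits for a batch of 1 image over 5 classes
outputs = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])

# .max(1) returns (values, indices) along the class dimension
top_score, predicted_index = outputs.max(1)

# Softmax turns logits into probabilities that sum to 1 for each image
probabilities = torch.softmax(outputs, dim=1)

# topk returns the k largest entries, sorted from most to least likely
top_probs, top_indices = torch.topk(probabilities, k=3, dim=1)

print(predicted_index.item())                 # 1
print(top_indices.tolist())                   # [[1, 3, 0]]
print(round(probabilities.sum().item(), 4))   # 1.0
```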

5. Map Class Index to Human-Readable Labels

The model outputs indices corresponding to ImageNet classes. We need to fetch a list of these class names to interpret the prediction.

# Load ImageNet class labels from a reliable source
# This URL points to a commonly used file containing ImageNet class names
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
class_labels = []
try:
    with urllib.request.urlopen(url) as f:
        class_labels = [line.decode('utf-8').strip() for line in f]
except Exception as e:
    print(f"Error fetching class labels: {e}")
    class_labels = ["Unknown"] * 1000 # Fallback if fetching fails

# Get the predicted class name using the predicted index
predicted_class_name = class_labels[predicted_index.item()]
print(f"Predicted class: {predicted_class_name}")

6. Display the Image and Prediction

Finally, we display the original image with the predicted class label as its title, using Matplotlib.

# Display the image and the prediction
plt.imshow(image)
plt.title(f"Predicted: {predicted_class_name}")
plt.axis('off') # Hide axes
plt.show()
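A single label hides how confident the model is. The sketch below (show_top_k is a hypothetical helper, not a library function) places the image next to a bar chart of the top-k softmax probabilities. The usage lines feed it placeholder data; with the tutorial's variables you would call show_top_k(image, torch.softmax(outputs, dim=1)[0], class_labels, k=5).

```python
import matplotlib.pyplot as plt
import torch
from PIL import Image

def show_top_k(image, probabilities, class_labels, k=3):
    """Plot the image next to a horizontal bar chart of its top-k classes.

    probabilities: 1-D tensor of softmax outputs for a single image.
    """
    top_probs, top_indices = torch.topk(probabilities, k)
    names = [class_labels[i] for i in top_indices.tolist()]

    fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(10, 4))
    ax_img.imshow(image)
    ax_img.set_title(f"Predicted: {names[0]} ({top_probs[0].item():.1%})")
    ax_img.axis("off")

    # Reverse the order so the most likely class is drawn at the top
    ax_bar.barh(names[::-1], top_probs.tolist()[::-1])
    ax_bar.set_xlim(0, 1)
    ax_bar.set_xlabel("Probability")
    fig.tight_layout()
    return fig

# Usage with placeholder data (hypothetical labels and scores)
demo_image = Image.new("RGB", (224, 224), color=(200, 170, 90))
demo_probs = torch.softmax(torch.tensor([2.0, 0.5, 1.0, -1.0]), dim=0)
demo_labels = ["class_a", "class_b", "class_c", "class_d"]
fig = show_top_k(demo_image, demo_probs, demo_labels, k=3)
# In an interactive session, call plt.show() to display the figure
```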

Output Example

Upon successful execution, the script prints the predicted label and displays your image with that label as its title. For a photo of a golden retriever, the console output would look like:

Predicted class: golden retriever

Summary

In this tutorial, you have learned how to:

  • Load and utilize a pre-trained ResNet18 model from PyTorch.
  • Apply essential image preprocessing techniques required by these models.
  • Execute image classification (inference) using the loaded model.
  • Interpret and visualize the model's predictions.

Why Use Pre-trained Models?

Leveraging pre-trained models offers significant advantages:

  • Faster Development: Significantly reduces training time and computational cost.
  • Improved Performance: Often achieves better accuracy, especially with limited data, by reusing knowledge learned on large datasets.
  • Strong Baselines: Starts from architectures already trained on large-scale benchmarks such as ImageNet, whose features transfer well to many visual recognition tasks.
  • Transferability: Easily adaptable to new, custom datasets through fine-tuning.

Next Steps

  • Explore Other Models: Experiment with different pre-trained architectures like resnet50, vgg16, or mobilenet_v2.
  • Fine-tuning: Adapt a pre-trained model to your specific dataset by training it on your data, typically by updating only the final layers or all layers with a lower learning rate.
  • Build Applications: Integrate this image classification capability into web applications using frameworks like Flask or Streamlit.


Interview Questions

Here are some common interview questions related to this topic:

  1. What is transfer learning, and why is it useful in image classification?

    • Transfer learning reuses a model trained on one task as the starting point for a related task. In image classification, a network trained on a large dataset such as ImageNet has already learned general visual features (edges, textures, object parts) that carry over to new problems. This is especially valuable when data or compute for the new task is limited: training is faster, and results are often better than training from scratch.
  2. How do you load and use a pre-trained ResNet18 model in PyTorch?

    • With torchvision 0.13 or later, load it with torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1); older code uses the deprecated pretrained=True. Then call model.eval() before inference so BatchNorm layers use their stored statistics.
  3. Why do we normalize input images before feeding them to a pre-trained model?

    • Pre-trained models were trained on images normalized with specific per-channel mean and standard deviation values (for ImageNet models, the ImageNet statistics). Normalizing new inputs the same way:
      • Ensures Consistency: Matches the input distribution the model saw during training; skipping it shifts every activation and typically degrades predictions.
      • Stabilized Training: During the original training, centered, unit-scale inputs made gradients less likely to explode or vanish.
      • Improved Convergence: The same centering helped the model converge faster.
  4. What are the typical image preprocessing steps for models trained on ImageNet?

    • Typical steps include:
      • Resizing: Scaling the image so its shorter side is 256 pixels, preserving the aspect ratio.
      • Cropping: Cropping the center part of the resized image to the target input size (e.g., 224x224 pixels).
      • Tensor Conversion: Converting the image data (usually from PIL Image or NumPy array) into a PyTorch Tensor.
      • Normalization: Subtracting the mean and dividing by the standard deviation for each color channel (RGB) using ImageNet's specific values.
  5. How can you perform inference on a single image using a pre-trained CNN?

    • The process involves:
      • Loading and preprocessing the image as described above.
      • Passing the preprocessed image tensor through the pre-trained model (model(input_tensor)).
      • Using torch.no_grad() context to disable gradient calculations, saving memory and computation.
      • Interpreting the output tensor (e.g., using .max(1) to find the highest-scoring class).
  6. How do you interpret the output of a pre-trained classification model?

    • torchvision classifiers return one raw, unnormalized score (logit) per class; they do not apply softmax. For N classes the output has shape (batch_size, N). Take the index of the highest value along dimension 1 (with .max(1) or .argmax(1)) to get the predicted class, apply torch.softmax(outputs, dim=1) if you need probabilities, and map the index to a human-readable name using a predefined label list (such as the ImageNet class list).
  7. What is the role of model.eval() in PyTorch?

    • model.eval() sets the model to evaluation mode. This is crucial because it affects certain layers like Dropout and BatchNorm:
      • Dropout: Dropout layers randomly set input units to zero during training. In evaluation mode, dropout is effectively turned off, and all units are used.
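      • BatchNorm: During training, BatchNorm layers normalize with the current batch's statistics and update their running mean and variance. In evaluation mode, they normalize with the stored running statistics instead, so predictions no longer depend on batch composition.
      • Gradients: model.eval() does not disable gradient tracking; wrap inference in torch.no_grad() (or torch.inference_mode()) for that.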
  8. How would you map model output indices to human-readable class names?

    • You need a mapping, typically a list or dictionary, where the index corresponds to the model's output index and the value is the human-readable class name. For ImageNet models, this mapping is usually obtained from a text file or a predefined list that comes with the dataset or is hosted online (like the imagenet_classes.txt file used in the example).
  9. What are the advantages of using pre-trained models compared to training from scratch?

    • Reduced Training Time & Resources: Significantly faster and requires less computational power.
    • Better Performance: Often achieves higher accuracy, especially on smaller datasets, due to knowledge transfer.
    • Feature Extraction: The initial layers of pre-trained models act as powerful feature extractors, capturing generic visual patterns.
    • Overfitting Mitigation: Less prone to overfitting on small datasets.
  10. How would you fine-tune a pre-trained CNN on your own dataset?

    • The typical steps for fine-tuning are:
      • Load a pre-trained model.
      • Replace the final classification layer(s) with new ones that match the number of classes in your dataset.
      • Freeze the weights of the early layers (optional but recommended initially) to retain the learned general features.
      • Train the new classification layers on your dataset.
      • Optionally, unfreeze some or all of the earlier layers and continue training with a very low learning rate to adapt the learned features to your specific data.