In this article, we walk through a Pokemon classification example and break down, step by step, how to build a simplified ResNet18 model in PyTorch to classify five types of Pokemon.
Project Overview
We will implement two key components:
- ResBlk: The basic residual block used to construct deeper ResNet architectures.
- ResNet18: A complete classification network built by stacking multiple ResBlk modules.
In addition, the main function demonstrates how to instantiate these modules and inspect their output shapes and parameter counts.
Code Explanation
Import Required Libraries
import torch
from torch import nn
from torch.nn import functional as F
- torch: Core PyTorch library.
- nn: Contains various neural network layers and loss functions.
- functional: Provides stateless operations such as activation functions and convolution operations.
Defining the ResBlk Module
class ResBlk(nn.Module):
    """
    ResNet Block (Residual Block)
    """
    ...
ResBlk is the fundamental building block of ResNet. By introducing residual (skip) connections, it helps alleviate the vanishing gradient problem in deep networks.
Initialization Method __init__
def __init__(self, ch_in, ch_out, stride=1):
    """
    Initialize the ResBlk module
    :param ch_in: Number of input channels
    :param ch_out: Number of output channels
    :param stride: Convolution stride
    """
    super(ResBlk, self).__init__()
    self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(ch_out)
    self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(ch_out)
    # Identity shortcut by default
    self.extra = nn.Sequential()
    if ch_out != ch_in:
        # Adjust dimensions with a 1x1 convolution when input/output channels differ
        self.extra = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
            nn.BatchNorm2d(ch_out)
        )
Explanation:
- Convolution + BatchNorm layers:
  - conv1: A 3×3 convolution that converts ch_in channels to ch_out, using the given stride. Padding of 1 preserves the spatial size when stride = 1.
  - bn1: Batch normalization applied after conv1.
  - conv2: A second 3×3 convolution that keeps the channel count at ch_out.
  - bn2: Batch normalization applied after conv2.
- Residual (shortcut) connection:
  - If ch_in != ch_out, a 1×1 convolution is used to match channel dimensions. This adjustment is stored in self.extra.
  - If the channels match, self.extra remains an empty sequential block, meaning the input is added directly to the output.
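To see both cases, you can inspect self.extra directly (a minimal sketch, assuming the ResBlk class above is in scope):

blk_same = ResBlk(64, 64)
print(blk_same.extra)   # Sequential() -- empty, identity shortcut

blk_proj = ResBlk(64, 128, stride=2)
print(blk_proj.extra)   # Sequential(Conv2d(64, 128, ...), BatchNorm2d(128))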
Forward Method forward
def forward(self, x):
    """
    Forward pass
    :param x: Input tensor with shape [batch_size, ch_in, height, width]
    :return: Output tensor with shape [batch_size, ch_out, height', width'] (spatial size depends on stride)
    """
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    # Residual connection
    out = self.extra(x) + out
    out = F.relu(out)
    return out
Explanation:
- Main path:
  - Input x goes through conv1 → bn1 → ReLU.
  - The intermediate output then goes through conv2 → bn2.
- Shortcut path:
  - The original input x is transformed by self.extra (if needed) to match the output channels.
- The shortcut branch and the main branch are added element-wise.
- A final ReLU is applied.
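A quick shape check of the forward pass (a sketch assuming ResBlk is defined as above; the input size is arbitrary):

x = torch.randn(2, 64, 56, 56)

# Identity shortcut: channels and spatial size preserved
print(ResBlk(64, 64)(x).shape)             # torch.Size([2, 64, 56, 56])

# Projection shortcut: channels double, spatial size halves
print(ResBlk(64, 128, stride=2)(x).shape)  # torch.Size([2, 128, 28, 28])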
Defining the ResNet18 Model
class ResNet18(nn.Module):
    def __init__(self, num_class):
        """
        Initialize the ResNet18 model
        :param num_class: Number of classification categories
        """
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # Four residual blocks
        self.blk1 = ResBlk(16, 32, stride=3)
        self.blk2 = ResBlk(32, 64, stride=3)
        self.blk3 = ResBlk(64, 128, stride=2)
        self.blk4 = ResBlk(128, 256, stride=2)
        # Fully connected layer
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)
Explanation:
- Initial convolution:
  - A 3×3 convolution with stride 3, expanding channels from 3 to 16.
  - Batch normalization.
- Residual block stack:
  - blk1 to blk4 progressively increase channel depth from 16 to 256.
  - Stride values control spatial downsampling.
- Classifier head:
  - The output of the final residual block is flattened and fed into a fully connected layer to produce class logits.
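To make the downsampling concrete, here is a block-by-block shape trace for a 224×224 input (a sketch assuming the classes above; each size follows from the standard convolution output formula):

model = ResNet18(num_class=5)
x = torch.randn(1, 3, 224, 224)

x = F.relu(model.conv1(x))
print(x.shape)  # torch.Size([1, 16, 74, 74]), since (224 - 3) // 3 + 1 = 74
for blk in (model.blk1, model.blk2, model.blk3, model.blk4):
    x = blk(x)
    print(x.shape)
# torch.Size([1, 32, 25, 25])
# torch.Size([1, 64, 9, 9])
# torch.Size([1, 128, 5, 5])
# torch.Size([1, 256, 3, 3])  -> matches the 256 * 3 * 3 input of outlayer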
Forward Method
def forward(self, x):
    """
    Forward pass
    :param x: Input tensor [batch_size, 3, height, width]
    :return: Output tensor [batch_size, num_class]
    """
    x = F.relu(self.conv1(x))
    x = self.blk1(x)
    x = self.blk2(x)
    x = self.blk3(x)
    x = self.blk4(x)
    x = x.view(x.size(0), -1)  # Flatten (assumes a 224x224 input, giving 256 * 3 * 3 features)
    x = self.outlayer(x)
    return x
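The view call collapses the [batch, 256, 3, 3] feature map into a matrix the linear layer can consume (a standalone sketch with illustrative shapes):

feat = torch.randn(2, 256, 3, 3)
flat = feat.view(feat.size(0), -1)
print(flat.shape)  # torch.Size([2, 2304]), since 256 * 3 * 3 = 2304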
Detailed Model Structure
Advantages of Residual Connections
Deep neural networks suffer from vanishing or exploding gradients as depth increases. ResNet addresses this by using residual connections, which allow gradients to flow more easily through the network, significantly improving optimization.
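Formally, a residual block computes y = x + F(x), so the gradient of the loss with respect to the input contains a direct identity term alongside the learned term; even when the learned gradient is small, the signal still propagates (a standard derivation, not specific to this implementation):

y = x + F(x), \qquad
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)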
Structure of ResBlk
Each ResBlk contains:
- Main path: Two 3×3 convolutions with batch normalization.
- Shortcut path: Identity mapping or a 1×1 convolution to match dimensions.
This design lets the model learn residual functions, making training deeper networks more stable and efficient.
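In the notation of the code above, the block computes:

\text{out} = \mathrm{ReLU}\big(\mathrm{bn2}(\mathrm{conv2}(\mathrm{ReLU}(\mathrm{bn1}(\mathrm{conv1}(x))))) + \mathrm{extra}(x)\big)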
Layer Structure of ResNet18
- Initial convolution
  - Converts 3 channels to 16.
- Four residual blocks
  - blk1: 16 → 32 (stride 3)
  - blk2: 32 → 64 (stride 3)
  - blk3: 64 → 128 (stride 2)
  - blk4: 128 → 256 (stride 2)
- Fully connected classifier
  - Flattens the feature map into a vector of size 256 × 3 × 3.
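Most parameters live in the later blocks, where the channel counts are largest. A per-module breakdown can be printed like this (a sketch assuming the model above; the counts follow from the layer definitions):

model = ResNet18(num_class=5)
for name, module in model.named_children():
    print(name, sum(p.numel() for p in module.parameters()))
# conv1 480
# blk1 14624
# blk2 57920
# blk3 230528
# blk4 919808
# outlayer 11525   (total: 1234885)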
Running the Code
Here is the complete implementation:
import torch
from torch import nn
from torch.nn import functional as F
...
# (full code preserved exactly as provided)
...
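The elided portion consists of the ResBlk and ResNet18 definitions shown earlier plus a small test driver. Based on the tests described below, a minimal main function might look like this (a sketch; variable names are illustrative):

def main():
    # Test ResBlk: 64 -> 128 channels, spatial size preserved at stride 1
    blk = ResBlk(64, 128)
    out = blk(torch.randn(2, 64, 224, 224))
    print('block:', out.shape)

    # Test ResNet18 with 5 classes
    model = ResNet18(5)
    out = model(torch.randn(2, 3, 224, 224))
    print('resnet:', out.shape)

    # Total parameter count
    total = sum(p.numel() for p in model.parameters())
    print('parameters size:', total)

if __name__ == '__main__':
    main()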
Explanation:
- Testing ResBlk:
  - A ResBlk is instantiated to convert 64 channels to 128.
  - Passing a random tensor of shape [2, 64, 224, 224] yields [2, 128, 224, 224].
- Testing ResNet18:
  - Instantiate the model for 5 classes.
  - A random input of shape [2, 3, 224, 224] produces an output of shape [2, 5].
- Parameter count:
  - Sum all parameters via p.numel().
Run the script:
python resnet.py
Expected output:
block: torch.Size([2, 128, 224, 224])
resnet: torch.Size([2, 5])
parameters size: 1234885
Both the ResBlk and ResNet18 modules function correctly and produce the expected shapes and parameter counts.
