Deep learning models—especially large neural networks—typically require vast amounts of data and computational resources to train. However, collecting large labeled datasets is often expensive, and in some domains (such as medical imaging or remote sensing), even obtaining raw data can be extremely difficult. In addition, training a deep model from scratch consumes substantial time and computing power, which is often impractical in real-world scenarios.
Transfer learning leverages knowledge learned from one task to accelerate the learning process of another related task. Its main advantages include:
- Reduced data requirements: By using a pre-trained model, strong performance can be achieved even with a relatively small dataset.
- Shorter training time: Since a pre-trained model has already learned rich feature representations, the fine-tuning process becomes significantly faster.
- Improved performance: Transfer learning often leads to better generalization, especially when the target task has limited data.
Fundamentals of Transfer Learning
The core idea of transfer learning is the transfer of knowledge. Specifically, transfer learning typically proceeds through the following steps:
- Pre-trained model: Train a deep neural network on a large dataset (e.g., ImageNet) to learn general-purpose feature representations.
- Feature transfer: Retain the earlier layers of the pre-trained model (which capture low-level features such as edges and textures). These features are usually universal across tasks.
- Task-specific layers: Modify or replace the final layers of the model according to the requirements of the new task (e.g., changing the number of output classes).
- Fine-tuning: Train the model on the dataset of the new task, allowing the model to adapt to task-specific characteristics while preserving general knowledge.
The key to effective transfer learning lies in choosing an appropriate pre-trained model and adjusting its architecture so that the learned knowledge carries over to the new task.
Transfer Learning in PyTorch
As a flexible and powerful deep learning framework, PyTorch offers extensive support for transfer learning. Below is an example illustrating how to implement transfer learning in PyTorch.
Environment Setup
First, make sure PyTorch and its dependencies are installed:
pip install torch torchvision
Loading a Pre-trained Model
PyTorch’s torchvision.models module provides a variety of pre-trained models, such as ResNet, VGG, and Inception. Using ResNet18 as an example:
import torch
import torchvision.models as models
# Load a ResNet18 pre-trained on ImageNet
# (in torchvision >= 0.13 the equivalent call is models.resnet18(weights=models.ResNet18_Weights.DEFAULT))
model = models.resnet18(pretrained=True)
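To see what will need to be replaced for a new task, it helps to inspect the classifier head of the loaded model. For ResNet18 this is the fc attribute, which maps 512 features to the 1,000 ImageNet classes:
print(model.fc)
# Linear(in_features=512, out_features=1000, bias=True)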
Modifying the Model Architecture
Depending on the target task, you usually need to modify the final fully connected layer (the classifier). For example, if the new task has 10 output classes:
import torch.nn as nn
# Get the number of input features of the FC layer
num_ftrs = model.fc.in_features
# Replace the final FC layer
model.fc = nn.Linear(num_ftrs, 10)
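Note that the name of the classifier head differs between architectures. As a brief example, in torchvision's VGG16 the head is a Sequential module called classifier, and its last linear layer is the one to replace:
# For VGG-style models the head is model.classifier, not model.fc
vgg = models.vgg16(pretrained=True)
num_ftrs = vgg.classifier[6].in_features  # 4096 for VGG16
vgg.classifier[6] = nn.Linear(num_ftrs, 10)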
Choosing Which Layers to Train
Transfer learning is typically applied in two ways:
- Freeze pre-trained layers: Train only the newly added layers while keeping the pre-trained layers fixed.
- Fine-tune the entire network: Continue training all layers so that the model adapts more fully to the new task.
Example: freeze all pre-trained layers:
for param in model.parameters():
    param.requires_grad = False
# Train only the new FC layer
for param in model.fc.parameters():
    param.requires_grad = True
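The second approach, fine-tuning the entire network, simply leaves requires_grad enabled for every parameter. A common refinement, sketched below as one possible setup (not the configuration used in the training example that follows), is to give the pre-trained backbone a smaller learning rate than the freshly initialized head so that the general features are not disrupted early in training:
import torch.optim as optim

# Re-enable gradients everywhere for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Smaller learning rate for the pre-trained backbone, larger for the new head
backbone_params = [p for name, p in model.named_parameters() if not name.startswith('fc')]
optimizer_finetune = optim.SGD([
    {'params': backbone_params, 'lr': 1e-4},        # pre-trained layers
    {'params': model.fc.parameters(), 'lr': 1e-2},  # new classifier head
], momentum=0.9)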
Training and Fine-Tuning
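The training loop below assumes a train_loader and train_dataset for the target task. A minimal sketch of how these might be built with torchvision is shown here; the 'data/train' path is a hypothetical placeholder, and the normalization statistics are the standard ImageNet values that match the pre-trained weights:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Standard ImageNet preprocessing, matching the pre-trained weights
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 'data/train' is a hypothetical path; any Dataset yielding (image, label) pairs works
train_dataset = datasets.ImageFolder('data/train', transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)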
Define a loss function and optimizer, then begin training:
import torch.optim as optim

# Optimize only the final FC layer
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Move the model to GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

num_epochs = 25
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # train_loader / train_dataset: see the data-loading sketch above
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    # Validation step can be added here (see the sketch below)
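A validation pass can be run at the end of each epoch. The sketch below assumes a val_loader and val_dataset built in the same way as the training loader; these names are assumptions and are not defined earlier in this example:
# Evaluate on a held-out validation set (val_loader / val_dataset are assumed to exist)
model.eval()
correct = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
val_acc = correct / len(val_dataset)
print(f'Validation accuracy: {val_acc:.4f}')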
Summary
Transfer learning significantly reduces the barriers to training deep neural networks, particularly when data is scarce or computational resources are limited. PyTorch provides a clean and powerful interface for loading, modifying, and training pre-trained models. By mastering the principles and implementation of transfer learning, you can make your deep learning projects far more efficient and effective.
