- Import Required Libraries and Set Hyperparameters
- Data Loading and Preprocessing
- Define Evaluation Function
- Training, Validation, and Testing Pipeline
- Model Saving and Loading
- Visualizing with Visdom
- Run the Code
In the first two posts of this series, we introduced how to create a custom Dataset class in PyTorch to load and preprocess a Pokemon image dataset, and how to build a simplified ResNet18 model. In this post, we walk through how to train, validate, and test that model for image classification.
- Data Preparation: Load the training, validation, and test sets using the `Pokemon` dataset class defined previously.
- Model Definition: Use the custom-built ResNet18 model.
- Training Process: Define the loss function and optimizer, and perform model training.
- Validation & Testing: Evaluate the model's performance on the validation and test sets.
- Visualization: Use Visdom to monitor loss and accuracy in real time during training.
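The walkthrough assumes a flat project layout along these lines (`train_scratch.py` is the script name used later in this post; the other names follow the imports from the earlier posts):

```
train_scratch.py   # this post's training script
pokemon.py         # the custom Pokemon dataset class (first post)
resnet.py          # the simplified ResNet18 model (second post)
pokemon/           # the image dataset, one subdirectory per class
```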
Import Required Libraries and Set Hyperparameters
```python
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader

# absolute imports, since train_scratch.py is run directly as a script
from pokemon import Pokemon
from resnet import ResNet18

batchsz = 32    # samples per mini-batch
lr = 1e-3       # learning rate
epochs = 10     # training epochs

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)  # fix the random seed for reproducibility

viz = visdom.Visdom()    # requires a running Visdom server (see below)
```
- Library Imports:
  - `torch`, `optim`, `nn`: Core PyTorch modules for building and training neural networks.
  - `visdom`: Real-time visualization for monitoring training.
  - `DataLoader`: Loads data in batches.
  - `Pokemon`, `ResNet18`: Custom dataset and network classes defined in earlier posts.
- Parameter Settings:
  - `batchsz`: Number of samples per training batch.
  - `lr`: Learning rate controlling the optimization step size.
  - `epochs`: Total number of training epochs.
- Device Configuration:
  - `device`: Uses the GPU if available, otherwise falls back to the CPU.
  - `torch.manual_seed(1234)`: Sets the random seed for reproducibility.
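The fixed seed makes weight initialization and data shuffling repeatable. For stricter run-to-run reproducibility on a GPU, you can additionally pin cuDNN's behavior; this is optional and not part of the original script:

```python
torch.backends.cudnn.deterministic = True  # force deterministic kernels
torch.backends.cudnn.benchmark = False     # disable kernel auto-tuning
```

Deterministic kernels can be slower, so this trades speed for repeatability.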
Data Loading and Preprocessing
```python
train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
```
- Instantiate the Datasets:
  - `train_db`, `val_db`, `test_db` represent the training, validation, and test sets.
  - The `Pokemon` class splits the dataset automatically based on the `mode` argument.
- Create Data Loaders:
  - `DataLoader` bundles the dataset into iterable mini-batches.
  - `shuffle=True` is used for training to improve generalization.
  - `num_workers` specifies how many subprocesses load data in parallel.
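Before training, it is worth pulling a single batch to confirm the tensor shapes match what the model expects. This check is illustrative and not part of the training script:

```python
x, y = next(iter(train_loader))
print(x.shape)  # expected: torch.Size([32, 3, 224, 224])
print(y.shape)  # expected: torch.Size([32])
```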
Define Evaluation Function
```python
def evaluate(model, loader):
    model.eval()  # evaluation mode: disables dropout, freezes batch-norm statistics
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():  # no gradient tracking needed for inference
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total
```
- Purpose of the Function:
  - `evaluate` computes accuracy over the given data loader (validation or test).
- Implementation Details:
  - `model.eval()` sets the model to evaluation mode, disabling dropout and batch-norm updates.
  - Predictions are computed without gradient tracking (`torch.no_grad()`).
  - `argmax` selects the class with the highest predicted score.
  - The function returns accuracy = correct predictions / total samples.
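To make the accuracy computation concrete, here is a tiny self-contained example with made-up logits for a two-class problem:

```python
import torch

logits = torch.tensor([[2.0, 0.5],   # argmax -> class 0
                       [0.1, 1.0],   # argmax -> class 1
                       [3.0, 0.2]])  # argmax -> class 0
y = torch.tensor([0, 1, 1])          # ground-truth labels

pred = logits.argmax(dim=1)                       # tensor([0, 1, 0])
correct = torch.eq(pred, y).sum().float().item()  # 2.0
print(correct / len(y))                           # 0.666..., i.e. 2 of 3 correct
```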
Training, Validation, and Testing Pipeline
```python
def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:  # validate every epoch; raise the modulus to validate less often
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
```
- Model Initialization:
  - `ResNet18(5)` builds a model with 5 output classes.
- Optimizer and Loss Function:
  - Adam optimizer with learning rate `1e-3`.
  - `CrossEntropyLoss` for multi-class classification (see the sketch after this list).
- Visualization Setup:
  - Initialize loss and accuracy charts in Visdom.
- Validation and Model Saving:
  - After each epoch, compute validation accuracy.
  - Save the best model to `best.mdl`.
- Load Best Model & Test:
  - Load the best-performing checkpoint.
  - Evaluate on the test set.
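One detail worth emphasizing about the loss: `nn.CrossEntropyLoss` expects raw logits rather than softmax probabilities, because it applies log-softmax internally. A minimal self-contained sketch:

```python
import torch
from torch import nn

criteon = nn.CrossEntropyLoss()
logits = torch.randn(4, 5)     # a batch of 4 samples, 5 classes, raw scores
y = torch.randint(0, 5, (4,))  # integer class labels in [0, 5)
loss = criteon(logits, y)      # log-softmax + negative log-likelihood in one op
print(loss.item())
```

Passing softmax outputs here is a common mistake that quietly degrades training.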
Model Saving and Loading
```python
if val_acc > best_acc:
    best_epoch = epoch
    best_acc = val_acc
    torch.save(model.state_dict(), 'best.mdl')
viz.line([val_acc], [global_step], win='val_acc', update='append')
```
- Saving:
  - Update the best epoch and accuracy.
  - Save the model parameters via `torch.save`.
- Loading:
  - Later retrieved using `model.load_state_dict(torch.load('best.mdl'))`.
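Since `best.mdl` stores only the parameter tensors (a `state_dict`), loading it requires constructing the same architecture first. If the checkpoint was saved on a GPU machine but loaded on a CPU-only one, pass `map_location`; the CPU fallback below is an illustrative assumption, not part of the original script:

```python
model = ResNet18(5)  # must match the architecture that produced the checkpoint
model.load_state_dict(torch.load('best.mdl', map_location='cpu'))
model.eval()         # switch to eval mode before running inference
```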
Visualizing with Visdom
```python
viz = visdom.Visdom()  # module level: connects to a running Visdom server

# In main()
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
...
viz.line([loss.item()], [global_step], win='loss', update='append')
...
viz.line([val_acc], [global_step], win='val_acc', update='append')
```
- Initialize the Visdom client.
- Create visualization windows for loss and accuracy.
- Update the charts during training using `update='append'`.
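If the script seems to hang at startup or the charts stay empty, the usual cause is that no Visdom server is running. A quick guard using the client's `check_connection` method:

```python
viz = visdom.Visdom()
assert viz.check_connection(), \
    'No Visdom server found; start one with: python -m visdom.server'
```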
Usage

Start the Visdom server:

```bash
python -m visdom.server
```

Then open http://localhost:8097 to view real-time training metrics.
Run the Code
- Start Visdom: `python -m visdom.server`
- Run the training script: `python train_scratch.py`
- View results: open http://localhost:8097
Expected Output (exact values will vary with hardware and library versions, even with a fixed seed):

```
best acc: 0.9098712446351931 best epoch: 6
loaded from ckpt!
test acc: 0.8841201716738197
```

