Pokemon (III) - Training, Validation, and Testing

  1. Import Required Libraries and Set Hyperparameters
  2. Data Loading and Preprocessing
  3. Define Evaluation Function
  4. Training, Validation, and Testing Pipeline
  5. Model Saving and Loading
  6. Visualizing with Visdom
    1. Usage
  7. Run the Code
    1. Expected Output

In the first two posts of this series, we showed how to create a custom PyTorch Dataset class that loads and preprocesses a Pokemon image dataset, and how to build a simplified ResNet18 model. In this post, we walk through training, validating, and testing that model for image classification. The workflow has five steps:

  1. Data Preparation: Load the training, validation, and test sets using the Pokemon dataset class defined previously.
  2. Model Definition: Use the custom-built ResNet18 model.
  3. Training Process: Define the loss function and optimizer, and perform model training.
  4. Validation & Testing: Evaluate the model’s performance on the validation set and test set.
  5. Visualization: Use Visdom to monitor loss and accuracy in real time during training.

Import Required Libraries and Set Hyperparameters

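import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader

# Plain (non-relative) imports so that `python train_scratch.py` works when
# pokemon.py and resnet.py sit in the same directory as the script
from pokemon import Pokemon
from resnet import ResNet18

batchsz = 32
lr = 1e-3
epochs = 10

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)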
  1. Library Imports:

    • torch, optim, nn: Core PyTorch modules for building and training neural networks.
    • visdom: Real-time visualization for monitoring training.
    • DataLoader: Used to load data in batches.
    • Pokemon, ResNet18: Custom dataset and network classes defined in earlier posts.
  2. Parameter Settings:

    • batchsz: Number of samples per training batch.
    • lr: Learning rate controlling optimization step size.
    • epochs: Total number of training epochs.
  3. Device Configuration:

    • device: Use GPU if available.
    • torch.manual_seed(1234): Set random seed for reproducibility.
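The device and seed settings can be checked in isolation. The sketch below (plain PyTorch, no dataset needed) picks the GPU when one is present and shows that reseeding with the same value reproduces the same random draws. Seeding alone does not make GPU runs bit-for-bit identical; setting torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False is a common addition, at some cost in speed.

```python
import torch

# Fall back to the CPU when no CUDA device is present
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sample(seed):
    torch.manual_seed(seed)  # seeds the random number generators on all devices
    return torch.randn(3)


a = sample(1234)
b = sample(1234)
print(device)             # cuda or cpu, depending on the machine
print(torch.equal(a, b))  # True
```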

Data Loading and Preprocessing

train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
  1. Instantiate the Datasets:

    • train_db, val_db, test_db represent the training, validation, and test sets.
    • The Pokemon class returns the appropriate split of the dataset according to the mode argument.
  2. Create Data Loaders:

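    • DataLoader wraps a dataset and yields it as iterable mini-batches.
    • shuffle=True is set only for the training loader, so the samples are reshuffled every epoch, which helps generalization; the validation and test loaders keep a fixed order.
    • num_workers sets how many subprocesses load data in parallel.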
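To see what the loaders yield without the Pokemon images on disk, the sketch below substitutes a TensorDataset of random tensors (an assumption: 3×32×32 instead of 3×224×224, to keep memory low) and uses num_workers=0 so everything runs in the main process. With 100 samples and batch_size=32, the loader produces ceil(100 / 32) = 4 batches, the last holding the 4 leftover samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake "images" with labels in 0..4, standing in for the training split
images = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 5, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32,
                    shuffle=True, num_workers=0)

# Shuffling changes which samples land in each batch, not the batch sizes
batch_sizes = [x.shape[0] for x, y in loader]
print(len(loader), batch_sizes)  # 4 [32, 32, 32, 4]
```

Passing drop_last=True would discard that short final batch.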

Define Evaluation Function

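def evalute(model, loader):
    model.eval()  # dropout off; batch norm uses its running statistics

    correct = 0
    total = len(loader.dataset)  # number of samples, not number of batches

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():  # no gradient tracking needed for evaluation
            logits = model(x)
            pred = logits.argmax(dim=1)  # index of the highest score = predicted class
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total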
  1. Purpose of the Function:

    • evalute computes accuracy over the given data loader (validation or test).
  2. Implementation Details:

    • model.eval() switches the model to evaluation mode: dropout is disabled and batch normalization uses its running statistics instead of per-batch statistics.
    • Predictions are computed without gradient tracking (torch.no_grad()).
    • argmax selects the class with the highest predicted score.
    • Returns accuracy = correct predictions / total samples.
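The accuracy computation can be checked on hand-made logits. In this sketch, evaluate_accuracy is a hypothetical variant of evalute that takes device as a parameter instead of reading a global, and nn.Identity() stands in for the network so that the inputs are the logits themselves. Three of the four arg-max predictions match their labels, so the accuracy is 0.75.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader, device):
    """Fraction of samples whose arg-max prediction equals the label."""
    model.eval()  # dropout off, batch norm uses running statistics
    correct = 0
    with torch.no_grad():  # no autograd graph is built during evaluation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
    return correct / len(loader.dataset)


# nn.Identity() passes the "logits" through unchanged
logits = torch.tensor([[2.0, 0.1, 0.3],   # predicts class 0
                       [0.2, 1.5, 0.1],   # predicts class 1
                       [0.3, 0.2, 0.9],   # predicts class 2
                       [1.2, 0.4, 0.1]])  # predicts class 0
labels = torch.tensor([0, 1, 2, 1])       # the last prediction is wrong
loader = DataLoader(TensorDataset(logits, labels), batch_size=2)
acc = evaluate_accuracy(nn.Identity(), loader, torch.device('cpu'))
print(acc)  # 0.75
```

Wrapping the whole loop in torch.no_grad(), as here, is equivalent to the post's per-batch context and slightly tidier.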

Training, Validation, and Testing Pipeline

def main():
    model = ResNet18(5).to(device)  # 5 Pokemon classes
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    # Create the two Visdom windows, each seeded with a dummy point at x=-1
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        # evalute() leaves the model in eval mode, so switch back every epoch
        model.train()
        for step, (x, y) in enumerate(train_loader):
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()        # backpropagate
            optimizer.step()       # update the weights

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 1 == 0:  # validate every epoch; raise the modulus to validate less often

            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')

            # Log validation accuracy every epoch, not only when it improves
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
  1. Model Initialization:

    • ResNet18(5) builds a model with 5 output classes.
  2. Optimizer and Loss Function:

    • Adam optimizer with learning rate 1e-3.
    • CrossEntropyLoss for multi-class classification.
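  3. Visualization Setup:

    • Initialize the loss and validation-accuracy charts in Visdom.
  4. Training Loop:

    • For each mini-batch: run the forward pass, compute the loss, zero the gradients, backpropagate, and update the weights with optimizer.step().
    • Append the batch loss to the loss chart.
  5. Validation and Model Saving:

    • After each epoch, compute validation accuracy.
    • Whenever validation accuracy improves, save the model parameters to best.mdl.
  6. Load Best Model & Test:

    • Load the best-performing checkpoint.
    • Evaluate it on the test set.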
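The same loop structure can be run end to end on a toy problem. This sketch is an assumption-laden stand-in, not the post's script: a linear layer replaces ResNet18, 300 synthetic 2-D points replace the Pokemon images, and the best weights are kept in memory (a cloned state_dict) rather than written to best.mdl. It keeps the order zero_grad, backward, step; validates after every epoch; and restores the best checkpoint at the end.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic, linearly separable 2-class data (stand-in for the Pokemon images)
x_all = torch.randn(300, 2)
y_all = (x_all[:, 0] + x_all[:, 1] > 0).long()
train_loader = DataLoader(TensorDataset(x_all[:200], y_all[:200]),
                          batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x_all[200:], y_all[200:]), batch_size=32)

model = nn.Linear(2, 2)  # stand-in for ResNet18(5)
optimizer = optim.Adam(model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()


def accuracy(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
    return correct / len(loader.dataset)


best_acc, best_epoch, best_state = 0.0, 0, None
for epoch in range(20):
    model.train()
    for x, y in train_loader:
        loss = criterion(model(x), y)
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # compute new gradients
        optimizer.step()       # update parameters
    val_acc = accuracy(model, val_loader)
    if val_acc > best_acc:     # keep a copy of the best weights so far
        best_acc, best_epoch = val_acc, epoch
        best_state = {k: v.clone() for k, v in model.state_dict().items()}

model.load_state_dict(best_state)  # restore the best checkpoint
print('best acc:', best_acc, 'best epoch:', best_epoch)
```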

Model Saving and Loading

if val_acc > best_acc:
    best_epoch = epoch
    best_acc = val_acc

    torch.save(model.state_dict(), 'best.mdl')

viz.line([val_acc], [global_step], win='val_acc', update='append')
  1. Saving:

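    • Whenever validation accuracy improves, record the epoch and accuracy.
    • Save the model's state_dict (its parameters and buffers) with torch.save.
  2. Loading:

    • The checkpoint is restored later with model.load_state_dict(torch.load('best.mdl')). The model must first be constructed with the same architecture, because a state_dict holds only tensors, not the network structure.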
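A save/load round trip can be verified directly. In this sketch a small nn.Sequential model stands in for ResNet18, and an io.BytesIO buffer stands in for the best.mdl file so nothing is written to disk; torch.save and torch.load accept file-like objects as well as paths. map_location='cpu' lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import io

import torch
from torch import nn


def make_model():
    # Same architecture every time; a state_dict does not store the structure
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))


torch.manual_seed(0)
model = make_model()

# Save only the parameters (state_dict); the buffer stands in for 'best.mdl'
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Restore into a freshly constructed model with the same architecture
restored = make_model()
buffer.seek(0)
restored.load_state_dict(torch.load(buffer, map_location='cpu'))

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```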

Visualizing with Visdom

viz = visdom.Visdom()

# In main()
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
...
viz.line([loss.item()], [global_step], win='loss', update='append')
...
viz.line([val_acc], [global_step], win='val_acc', update='append')
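  1. Initialize the Visdom client with visdom.Visdom(), which connects to the server at http://localhost:8097 by default.
  2. Create one window for the loss and one for validation accuracy; each starts with a dummy point at x = -1, y = 0.
  3. Update the charts during training with update='append', which adds new points to the existing window instead of redrawing it.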
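The two-step pattern (create the window, then append) can be wrapped in a small class. LinePlot below is a hypothetical helper, not part of Visdom; it only assumes an object whose line(Y, X, win=..., opts=..., update=...) method behaves like visdom.Visdom().line, exactly as the calls above use it. Creating the window with the first real value avoids the dummy point at x = -1. So that the sketch runs without a Visdom server, RecordingViz (also hypothetical) records the calls instead of drawing them.

```python
class LinePlot:
    """One Visdom line window that grows point by point (hypothetical helper).

    The first point creates the window with a title; later points are
    appended with update='append', mirroring the viz.line calls in main().
    """

    def __init__(self, viz, win, title=None):
        self.viz = viz  # any object with a visdom-style line(Y, X, ...) method
        self.win = win
        self.title = title or win
        self.created = False

    def add(self, x, y):
        if self.created:
            self.viz.line([y], [x], win=self.win, update='append')
        else:
            self.viz.line([y], [x], win=self.win, opts=dict(title=self.title))
            self.created = True


class RecordingViz:
    """Stand-in for visdom.Visdom() that records calls instead of drawing."""

    def __init__(self):
        self.calls = []

    def line(self, Y, X, **kwargs):
        self.calls.append((Y, X, kwargs))


viz = RecordingViz()  # with a running server: visdom.Visdom()
loss_plot = LinePlot(viz, 'loss')
for step, loss in enumerate([2.3, 1.9, 1.4]):
    loss_plot.add(step, loss)
print(viz.calls[0])  # ([2.3], [0], {'win': 'loss', 'opts': {'title': 'loss'}})
print(viz.calls[1])  # ([1.9], [1], {'win': 'loss', 'update': 'append'})
```

With a server running, pass visdom.Visdom() instead and call loss_plot.add(global_step, loss.item()) inside the training loop.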

Usage

Start Visdom server:

python -m visdom.server

Then open http://localhost:8097 in a browser to watch the training metrics in real time.

Run the Code

  1. Start Visdom:

python -m visdom.server

  2. Run the training script:

python train_scratch.py

  3. View the results by opening http://localhost:8097 in a browser.

Expected Output

best acc: 0.9098712446351931 best epoch: 6
loaded from ckpt!
test acc: 0.8841201716738197
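Because evalute returns correct / total, each printed accuracy is a ratio of two integers. A quick check with Python's fractions module shows the values are consistent with 212/233 and 206/233 in lowest terms, i.e. 212 and 206 correctly classified images if each split holds 233 images. That split size is an inference from the numbers, not something the post states.

```python
from fractions import Fraction

# Accuracy values copied from the expected output above
best_val_acc = 0.9098712446351931
test_acc = 0.8841201716738197

# Closest fraction to each float with a denominator of at most 1000
val_frac = Fraction(best_val_acc).limit_denominator(1000)
test_frac = Fraction(test_acc).limit_denominator(1000)
print(val_frac, test_frac)  # 212/233 206/233
```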