- Project Overview
- Directory Structure
- Custom Dataset Class: Pokemon
- Main Function & Data Visualization
- Run the Code
PyTorch provides a flexible Dataset and DataLoader interface, which makes creating custom datasets both convenient and efficient. In this article, we will walk through a real-world example to demonstrate how to build a custom image classification dataset for Pokemon images using PyTorch, and how to visualize the samples with Visdom.
Project Overview
We will build a custom dataset class named Pokemon, which will be able to:
- Load Pokemon images from a specified directory.
- Automatically generate or read a CSV file that stores image paths and labels.
- Split the dataset into training, validation, and testing sets.
- Apply a set of data augmentation transforms.
- Visualize the samples using Visdom.
Directory Structure
Before diving into the code, let’s understand the directory structure of the dataset. Suppose our Pokemon dataset is stored in a folder named pokemon:
pokemon/
├── bulbasaur/
│ ├── 0001.png
│ ├── 0002.png
│ └── ...
├── charmander/
│ ├── 0001.jpg
│ ├── 0002.jpg
│ └── ...
├── squirtle/
│ ├── 0001.jpeg
│ ├── 0002.jpeg
│ └── ...
└── ...
Each subfolder represents one Pokemon category and contains that category's image files. The loader picks up files with the .png, .jpg, and .jpeg extensions.
Custom Dataset Class: Pokemon
Initialization Method __init__
def __init__(self, root, resize, mode):
    super(Pokemon, self).__init__()
    self.root = root
    self.resize = resize
    self.name2label = {}  # "bulbasaur": 0, "charmander": 1, ...
    for name in sorted(os.listdir(os.path.join(root))):
        if not os.path.isdir(os.path.join(root, name)):
            continue
        self.name2label[name] = len(self.name2label.keys())
    # Load image paths and labels
    self.images, self.labels = self.load_csv('images.csv')
    # Split dataset based on mode
    if mode == 'train':  # 60%
        self.images = self.images[:int(0.6 * len(self.images))]
        self.labels = self.labels[:int(0.6 * len(self.labels))]
    elif mode == 'val':  # 20%
        self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
        self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
    else:  # test 20%
        self.images = self.images[int(0.8 * len(self.images)):]
        self.labels = self.labels[int(0.8 * len(self.labels)):]
Function Explanation:
- Parameter Description:
  - root: Path to the dataset root directory.
  - resize: Target size of the resized images.
  - mode: Dataset mode; supports train, val, and test.
- Category-to-label Mapping:
  - Iterates over all subfolders (each representing a Pokemon category) and assigns each one a unique label index.
- Loading Images and Labels:
  - Calls load_csv('images.csv') to load file paths and labels. If the CSV doesn't exist, it will be generated automatically.
- Dataset Splits:
  - Splits the dataset into training (60%), validation (20%), and testing (20%) according to the mode.
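The slicing arithmetic behind the 60/20/20 split can be checked in isolation. The following is a minimal sketch, not part of the Pokemon class: split_bounds is a hypothetical helper and the dataset size of 1167 is an arbitrary value chosen for illustration.

```python
def split_bounds(n, mode):
    """Return the (start, stop) slice bounds for a 60/20/20 split of n items."""
    train_end = int(0.6 * n)  # int() truncates, so boundaries round down
    val_end = int(0.8 * n)
    if mode == 'train':
        return 0, train_end
    if mode == 'val':
        return train_end, val_end
    # Any other mode is treated as 'test', mirroring the else branch in __init__.
    return val_end, n


n = 1167  # hypothetical dataset size
for mode in ('train', 'val', 'test'):
    start, stop = split_bounds(n, mode)
    print(mode, start, stop, stop - start)
```

Because the three slices share the same two boundaries, every sample lands in exactly one split, and the test split absorbs whatever remains after truncation.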
Loading or Generating CSV File load_csv
def load_csv(self, filename):
    # Build the CSV only if it does not exist yet
    if not os.path.exists(os.path.join(self.root, filename)):
        images = []
        for name in self.name2label.keys():
            images += glob.glob(os.path.join(self.root, name, '*.png'))
            images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
        print(len(images), images)
        random.shuffle(images)
        with open(os.path.join(self.root, filename), mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:
                # The parent folder name is the category name
                name = img.split(os.sep)[-2]
                label = self.name2label[name]
                writer.writerow([img, label])
            print('written into csv file:', filename)
    # Read CSV file
    images, labels = [], []
    with open(os.path.join(self.root, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
    assert len(images) == len(labels)
    return images, labels
Function Explanation:
- Check if CSV Exists:
  - If images.csv is missing, the code collects all images from category folders.
- Shuffle Data:
  - random.shuffle ensures randomness in dataset order.
- Generate CSV File:
  - Writes rows with format [image_path, label].
- Load Existing CSV:
  - Reads the CSV content into images and labels lists.
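One consequence of this design is that the shuffled order is written once and reused, so the train/val/test split stays the same across runs until images.csv is deleted. The sketch below illustrates that behavior on a throwaway directory. It is a standalone approximation rather than the class method: build_index and its seed parameter are hypothetical names introduced here, and the placeholder files are empty.

```python
import csv
import glob
import os
import random
import tempfile


def build_index(root, name2label, filename='images.csv', seed=None):
    """Write a shuffled (path, label) CSV once, then always read it back."""
    csv_path = os.path.join(root, filename)
    if not os.path.exists(csv_path):
        images = []
        for name in name2label:
            for ext in ('*.png', '*.jpg', '*.jpeg'):
                images += glob.glob(os.path.join(root, name, ext))
        images.sort()  # glob order depends on the filesystem
        random.Random(seed).shuffle(images)  # seed is an assumption, not in the original
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:
                # The parent folder name is the category name.
                name = os.path.basename(os.path.dirname(img))
                writer.writerow([img, name2label[name]])
    with open(csv_path, newline='') as f:
        return [(img, int(label)) for img, label in csv.reader(f)]


# Tiny throwaway dataset with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for name in ('bulbasaur', 'charmander'):
        os.makedirs(os.path.join(root, name))
        for i in range(3):
            open(os.path.join(root, name, f'{i:04d}.png'), 'w').close()
    name2label = {'bulbasaur': 0, 'charmander': 1}
    first = build_index(root, name2label, seed=0)
    second = build_index(root, name2label, seed=1)  # CSV exists, so no reshuffle
    print(len(first), first == second)
```

The second call ignores its seed because the CSV already exists, which is the same reason the original class returns an identical ordering on every run.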
Getting Dataset Length __len__
def __len__(self):
    return len(self.images)
Returns the number of samples in the selected split (train, val, or test). The DataLoader relies on this value to know how many samples it can draw.
De-normalization denormalize
def denormalize(self, x_hat):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # Reshape [3] -> [3, 1, 1] so the statistics broadcast over [3, H, W]
    mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
    std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
    # Invert x_hat = (x - mean) / std
    x = x_hat * std + mean
    return x
Function Explanation:
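This reverses the Normalize step applied in __getitem__, which computes x_hat = (x - mean) / std per channel with the ImageNet mean and std. Multiplying by std and adding mean brings pixel values back to the [0, 1] range so the image displays with its original colors.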
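The round trip can be verified on a single pixel without any tensor library. This is a minimal sketch in plain Python: the normalize and denormalize functions here are standalone illustrations operating on one RGB triple, not the class method, and the sample pixel is arbitrary.

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel mean passed to transforms.Normalize
STD = [0.229, 0.224, 0.225]   # per-channel std passed to transforms.Normalize


def normalize(pixel):
    """Per-channel (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(pixel, MEAN, STD)]


def denormalize(pixel_hat):
    """Per-channel x_hat * std + mean, the inverse of normalize."""
    return [x * s + m for x, m, s in zip(pixel_hat, MEAN, STD)]


pixel = [0.2, 0.5, 0.9]  # one RGB pixel in [0, 1], an arbitrary example
restored = denormalize(normalize(pixel))
# Compare with a tolerance because floating-point arithmetic is not exact.
print(all(abs(a - b) < 1e-9 for a, b in zip(pixel, restored)))
```

In the tensor version, the two unsqueeze calls play the role of the zip here: they align each channel's statistic with every pixel of that channel.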
Getting a Single Sample __getitem__
def __getitem__(self, idx):
    # img is still a file path at this point
    img, label = self.images[idx], self.labels[idx]
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # path -> RGB image
        transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
        transforms.RandomRotation(15),
        transforms.CenterCrop(self.resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    img = tf(img)
    label = torch.tensor(label)
    return img, label
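Function Explanation:
- Load Image: Opens the file at the stored path and converts it to RGB.
- Image Transformations:
  - Resize to 1.25× the target size
  - RandomRotation (±15°)
  - CenterCrop back to the target size
  - Convert to tensor
  - Normalize with the ImageNet mean and std
- Return: The image tensor and its label as a tensor.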
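A plausible reason for enlarging the image before rotating it is that rotation leaves empty corners, and the center crop then cuts them away. The sketch below works through that geometry for resize=64. It only mirrors the arithmetic and is not torchvision's implementation: center_crop_box and inscribed_square_side are hypothetical helpers, and the result ignores pixel rounding and interpolation.

```python
import math


def center_crop_box(size, crop):
    """Return (left, top, right, bottom) of a centered crop x crop window
    inside a size x size image."""
    offset = (size - crop) // 2
    return offset, offset, offset + crop, offset + crop


def inscribed_square_side(size, degrees):
    """Side of the largest centered, axis-aligned square that fits inside a
    size x size square rotated by the given angle (valid for 0 to 90 degrees)."""
    theta = math.radians(abs(degrees))
    return size / (math.cos(theta) + math.sin(theta))


resize = 64                    # the value passed to Pokemon in main()
enlarged = int(resize * 1.25)  # side length after transforms.Resize
box = center_crop_box(enlarged, resize)
safe_side = inscribed_square_side(enlarged, 15)  # worst case for RandomRotation(15)
print(enlarged, box, round(safe_side, 1))
```

Under these assumptions the safe side is slightly larger than the crop size even at the maximum rotation, so the cropped region should contain image content only.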
Main Function & Data Visualization
def main():
    import visdom
    import time
    viz = visdom.Visdom()
    db = Pokemon('pokemon', 64, 'train')
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)
- Initialize Visdom: A flexible real-time visualization tool.
- Create Dataset Instance.
- Visualize Single Sample.
- Create DataLoader.
- Visualize Batch Data: Including images and labels.
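To get a feel for what the loop displays, the batch and grid sizes can be estimated ahead of time. This is a rough sketch with assumed numbers: batch_layout is a hypothetical helper, the 700-sample training split is illustrative, and it assumes the DataLoader keeps the final partial batch (its default behavior).

```python
import math


def batch_layout(n_samples, batch_size, nrow):
    """Return (number of batches, size of the last batch, grid rows for a
    full batch, grid rows for the last batch). Assumes n_samples > 0."""
    n_batches = math.ceil(n_samples / batch_size)
    last = n_samples - (n_batches - 1) * batch_size
    return (n_batches, last,
            math.ceil(batch_size / nrow), math.ceil(last / nrow))


# 700 training samples is a made-up figure; 32 and 8 come from main().
print(batch_layout(700, 32, 8))
```

With nrow=8, a full batch of 32 images fills four rows of the Visdom grid, and the loop pauses 10 seconds on each batch before replacing it in the same window.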
Run the Code
Make sure the required libraries (PyTorch, Torchvision, Pillow, and Visdom) are installed.
Start Visdom server:
python -m visdom.server
Run the script:
python pokemon.py
Open your browser and visit:
http://localhost:8097
to view the visualized images and labels.

