Introduction
Computer vision lets machines interpret visual information, and image classification is one of its core tasks. The CIFAR-10 dataset, consisting of 60,000 32x32 color images across 10 classes (50,000 for training and 10,000 for testing), is a widely used benchmark for learning deep learning fundamentals. In this tutorial, we'll build a Convolutional Neural Network (CNN) from scratch using PyTorch to classify these images.
This project demonstrates the complete machine learning pipeline: from data preprocessing and augmentation to model training, evaluation, and deployment. We'll also integrate TensorBoard for real-time visualization and create a desktop application using Tkinter.
Project Repository
The complete source code for this CIFAR-10 classifier is available on GitHub:
View GitHub Repository
Project Overview
Our CIFAR-10 classification system includes the following key components:
- Dataset: CIFAR-10 (60,000 images across 10 classes)
- Framework: PyTorch with CUDA support
- Architecture: Custom CNN with 2 convolutional layers
- Data Augmentation: Random flips, rotations, and crops
- Monitoring: TensorBoard for real-time visualization
- Deployment: Tkinter-based desktop application
CIFAR-10 Classes:
- 0: Airplane
- 1: Automobile
- 2: Bird
- 3: Cat
- 4: Deer
- 5: Dog
- 6: Frog
- 7: Horse
- 8: Ship
- 9: Truck
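In code, the integer labels above are usually turned back into names with a simple lookup. A minimal sketch follows; the names `CIFAR10_CLASSES` and `label_to_name` are illustrative and not taken from the project repository.

```python
# Class names in label order, matching the list above.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def label_to_name(label: int) -> str:
    """Translate an integer label (0-9) into its class name."""
    if not 0 <= label < len(CIFAR10_CLASSES):
        raise ValueError(f"label must be in 0..9, got {label}")
    return CIFAR10_CLASSES[label]


print(label_to_name(3))  # cat
```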
Data Preparation
Data Loading
We start by creating a comprehensive data loading function that handles dataset download, normalization, and splitting:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

def load_cifar10_data(root_dir, batch_size=64, num_workers=2):
    """
    Load CIFAR-10 from root_dir and return train, validation, and test loaders.
    """
    # Basic normalization transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),  # CIFAR-10 mean
                             (0.2023, 0.1994, 0.2010))  # CIFAR-10 std
    ])
    # Load datasets (downloaded on first use)
    full_dataset = datasets.CIFAR10(root=root_dir, train=True,
                                    download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root=root_dir, train=False,
                                    download=True, transform=transform)
    # Split training data into train and validation sets (80-20 split)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    # Create data loaders (only the training loader is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader
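To make the Normalize step concrete: after ToTensor scales pixel values to [0, 1], each channel is shifted and scaled as (x - mean) / std. The sketch below reproduces that arithmetic in plain Python for a single RGB pixel; `normalize_pixel` is a hypothetical helper for illustration, not part of torchvision.

```python
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Apply per-channel (x - mean) / std to one RGB pixel in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(CIFAR10_MEAN))  # (0.0, 0.0, 0.0)
# Pure white lands a few standard deviations above zero.
print(normalize_pixel((1.0, 1.0, 1.0)))
```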
Data Augmentation
Data augmentation is crucial for improving model generalization. We implement various augmentation techniques:
def data_augmentation(data_dir, batch_size=64, num_workers=2):
    """
    Load CIFAR-10 with augmentation applied to the training split only.
    """
    import torch
    from torch.utils.data import Subset

    # Training augmentation pipeline
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),   # 50% chance of horizontal flip
        transforms.RandomRotation(degrees=15),    # Random rotation ±15 degrees
        transforms.RandomCrop(32, padding=4),     # Random crop with padding
        transforms.ToTensor(),                    # Convert to tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])
    # Validation/Test transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])
    # Two views of the official training set: augmented and plain
    train_full = datasets.CIFAR10(root=data_dir, train=True,
                                  download=True, transform=train_transform)
    val_full = datasets.CIFAR10(root=data_dir, train=True,
                                download=True, transform=val_transform)
    test_dataset = datasets.CIFAR10(root=data_dir, train=False,
                                    download=True, transform=val_transform)
    # Hold out 20% of the training images for validation,
    # so the test set stays unseen until the final evaluation
    indices = torch.randperm(len(train_full)).tolist()
    train_size = int(0.8 * len(train_full))
    train_dataset = Subset(train_full, indices[:train_size])
    val_dataset = Subset(val_full, indices[train_size:])
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader
Augmentation Benefits:
- Random Horizontal Flip: Exposes the model to mirrored versions of each object
- Random Rotation: Adds robustness to small changes in orientation (up to ±15 degrees here)
- Random Crop: Improves translation invariance
- Normalization: Stabilizes training with standardized inputs
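To see what the flip and padded crop do to pixel positions, here is a pure-Python sketch on a tiny 2D grid. It mimics the idea behind RandomHorizontalFlip and RandomCrop(32, padding=4) but is not torchvision's implementation; the helper names and the toy 4x4 grid are assumptions for illustration.

```python
import random


def hflip(image):
    """Mirror a 2D grid left to right."""
    return [row[::-1] for row in image]


def random_crop_with_padding(image, size, padding, rng):
    """Zero-pad the grid on every side, then cut a random size x size window."""
    height, width = len(image), len(image[0])
    padded = [[0] * (width + 2 * padding) for _ in range(height + 2 * padding)]
    for r in range(height):
        for c in range(width):
            padded[r + padding][c + padding] = image[r][c]
    # randint bounds are inclusive, so every valid window start is reachable
    top = rng.randint(0, height + 2 * padding - size)
    left = rng.randint(0, width + 2 * padding - size)
    return [row[left:left + size] for row in padded[top:top + size]]


rng = random.Random(0)
image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
flipped = hflip(image)
cropped = random_crop_with_padding(image, size=4, padding=1, rng=rng)
print(flipped[0])  # [4, 3, 2, 1]
for row in cropped:
    print(row)
```

With a 32x32 image and padding=4, the padded image is 40x40, so the crop window can start at 9 x 9 = 81 positions, shifting the content by up to 4 pixels in each direction.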
CNN Architecture
SimpleCNN Design
Our CNN follows the classic Conv -> ReLU -> Pool pattern, sized for CIFAR-10's 32x32 images:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   # 3 -> 32 channels
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 32 -> 64 channels
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)  # 2x2 max pooling
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 8 * 8, 512)  # 64 channels * 8x8 spatial size
        self.fc2 = nn.Linear(512, 10)          # 10 output classes
        # Regularization
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        # First conv block: Conv -> ReLU -> Pool
        x = F.relu(self.conv1(x))  # [batch, 32, 32, 32]
        x = self.pool(x)           # [batch, 32, 16, 16]
        # Second conv block: Conv -> ReLU -> Pool
        x = F.relu(self.conv2(x))  # [batch, 64, 16, 16]
        x = self.pool(x)           # [batch, 64, 8, 8]
        # Flatten for fully connected layers
        x = x.view(-1, 64 * 8 * 8)  # [batch, 4096]
        # Fully connected layers with dropout
        x = F.relu(self.fc1(x))  # [batch, 512]
        x = self.dropout(x)      # Regularization
        x = self.fc2(x)          # [batch, 10] - raw class scores (logits)
        return x
Model Components Explained
Architecture Breakdown:
- Input Layer: 3x32x32 (RGB images)
- Conv1 + ReLU + Pool: 32x16x16 feature maps
- Conv2 + ReLU + Pool: 64x8x8 feature maps
- Flatten: 4096-dimensional vector
- FC1 + ReLU + Dropout: 512 hidden units
- FC2: 10 output classes
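The feature-map sizes listed above follow from the standard output-size formula for convolution and pooling, out = floor((n + 2p - k) / s) + 1. The sketch below traces the spatial size through the network in plain Python; `conv_out` is a hypothetical helper, not a PyTorch function.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


size = 32
size = conv_out(size, kernel=3, padding=1)  # conv1 keeps 32
size = conv_out(size, kernel=2, stride=2)   # pool -> 16
size = conv_out(size, kernel=3, padding=1)  # conv2 keeps 16
size = conv_out(size, kernel=2, stride=2)   # pool -> 8
flat_features = 64 * size * size
print(size, flat_features)  # 8 4096
```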
Key Design Decisions:
- Padding=1: With 3x3 kernels, keeps the spatial dimensions unchanged through each convolution
- MaxPooling: Halves the spatial size at each stage while keeping the strongest activations
- Dropout: Reduces overfitting by randomly zeroing activations during training
- ReLU Activation: Introduces non-linearity and helps mitigate vanishing gradients
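One consequence of these choices is where the parameters end up. A rough count, assuming the layer sizes shown in SimpleCNN (the helper names are illustrative), suggests that the first fully connected layer dominates:

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights plus biases of a square-kernel Conv2d layer."""
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    """Weights plus biases of a Linear layer."""
    return in_features * out_features + out_features


counts = {
    "conv1": conv_params(3, 32, 3),
    "conv2": conv_params(32, 64, 3),
    "fc1": linear_params(64 * 8 * 8, 512),
    "fc2": linear_params(512, 10),
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name}: {n:,} ({n / total:.1%})")
print(f"total: {total:,}")  # total: 2,122,186
```

Since fc1 holds about 99% of the weights, it is a natural place for dropout, which may be why the dropout layer sits right after it.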
Training Pipeline
Comprehensive Training Function
Our training pipeline includes validation, metrics tracking, and TensorBoard integration:
def train(model, train_loader, val_loader, criterion, optimizer,
          device, num_epochs=10, writer=None, excel_path="training_metrics.xlsx"):
    import torch
    import pandas as pd

    # Initialize training history
    history = {
        'epoch': [], 'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'val_preds': [], 'val_labels': []
    }
    model.to(device)
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            # Accumulate metrics (batch-mean loss weighted by batch size)
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
        # Calculate training metrics
        train_loss = running_loss / len(train_loader.dataset)
        train_acc = correct / len(train_loader.dataset)
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())
                val_correct += (predicted == labels).sum().item()
        val_loss = val_loss / len(val_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)
        # Log to TensorBoard
        if writer is not None:
            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Loss/Val", val_loss, epoch)
            writer.add_scalar("Accuracy/Train", train_acc, epoch)
            writer.add_scalar("Accuracy/Val", val_acc, epoch)
        # Store metrics
        history['epoch'].append(epoch + 1)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_preds'].append(val_preds)
        history['val_labels'].append(val_labels)
        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}] | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    # Export the per-epoch scalar metrics to Excel (requires openpyxl);
    # the per-sample prediction lists stay in `history` only
    scalar_keys = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc']
    df_history = pd.DataFrame({key: history[key] for key in scalar_keys})
    df_history.to_excel(excel_path, index=False)
    print(f"\nTraining complete. Metrics saved to: {excel_path}")
    return history
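The training loop multiplies each batch loss by images.size(0) before accumulating. Since CrossEntropyLoss returns the batch mean by default, this weighting recovers the per-sample average even when the last batch is smaller than the rest. A small sketch with made-up per-sample losses illustrates the difference:

```python
# Hypothetical per-sample losses for 5 samples, split into batches of 2, 2, 1.
batches = [[1.0, 3.0], [2.0, 2.0], [7.0]]

# What a mean-reducing loss function would report for each batch.
batch_means = [sum(b) / len(b) for b in batches]

# Naive: averaging the batch means overweights the small final batch.
naive = sum(batch_means) / len(batch_means)

# Weighted: scale each batch mean by its size, then divide by the sample count.
running = sum(mean * len(b) for mean, b in zip(batch_means, batches))
weighted = running / sum(len(b) for b in batches)

print(naive)     # about 3.67
print(weighted)  # 3.0, the true per-sample mean
```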
TensorBoard Integration
We integrate TensorBoard for real-time monitoring of training progress:
# TensorBoard setup
from torch.utils.tensorboard import SummaryWriter
# Initialize TensorBoard writer
writer = SummaryWriter(log_dir="runs/cifar10_experiment")
# Training with TensorBoard logging
history = train(model, train_loader, val_loader, criterion,
                optimizer, device, num_epochs=20, writer=writer)
# Close writer
writer.close()
# To view TensorBoard: run "tensorboard --logdir=runs" in terminal
TensorBoard Features
- Loss Tracking: Real-time training and validation loss plots
- Accuracy Monitoring: Training and validation accuracy curves
- Scalars Dashboard: All metrics in one view
- Live Updates: Automatic refresh during training
Access TensorBoard: Run tensorboard --logdir=runs and open http://localhost:6006 in a browser.
Results & Analysis
Performance Metrics Generation
We generate comprehensive classification metrics and confusion matrices:
def generate_metrices(history):
    """
    Generate a classification report and confusion matrix from the
    final epoch's validation predictions.
    """
    import os
    from sklearn.metrics import classification_report, confusion_matrix
    import pandas as pd
    import numpy as np

    os.makedirs("results", exist_ok=True)  # to_csv does not create directories
    # Use the last epoch only, so the metrics describe the trained model
    # rather than a mix of early and late epochs
    preds = np.asarray(history['val_preds'][-1])
    labels = np.asarray(history['val_labels'][-1])
    # Generate classification report
    report = classification_report(labels, preds, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv("results/classification_report.csv", index=True)
    # Generate confusion matrix (rows: true labels, columns: predictions)
    cm = confusion_matrix(labels, preds)
    cm_df = pd.DataFrame(cm)
    cm_df.to_csv("results/confusion_matrix.csv", index=True)
    return report_df, cm_df

# Generate metrics after training
report_df, cm_df = generate_metrices(history)
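For readers who want to see what the report contains, the sketch below computes a confusion matrix and per-class precision and recall by hand for a toy three-class example. It follows scikit-learn's convention (rows are true labels, columns are predictions), but the function names and data are made up for illustration.

```python
def confusion(labels, preds, num_classes):
    """cm[t][p] counts samples with true class t predicted as class p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        cm[t][p] += 1
    return cm


def precision_recall(cm, cls):
    """Precision and recall for one class, 0.0 when undefined."""
    tp = cm[cls][cls]
    predicted = sum(row[cls] for row in cm)  # column sum
    actual = sum(cm[cls])                    # row sum
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    return precision, recall


labels = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]
cm = confusion(labels, preds, num_classes=3)
print(cm)                       # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(precision_recall(cm, 1))  # precision 2/3, recall 1.0
```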
Training Visualization
Create comprehensive training plots to analyze model performance:
def plot_training(history):
    """
    Create training and validation loss/accuracy plots.
    """
    import os
    import matplotlib.pyplot as plt

    os.makedirs("results", exist_ok=True)  # savefig does not create directories
    epochs = range(1, len(history['train_loss']) + 1)
    plt.figure(figsize=(12, 5))
    # Loss subplot
    plt.subplot(1, 2, 1)
    plt.plot(epochs, history['train_loss'], 'b-o', label='Train Loss')
    plt.plot(epochs, history['val_loss'], 'r-o', label='Val Loss')
    plt.title('Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    # Accuracy subplot
    plt.subplot(1, 2, 2)
    plt.plot(epochs, history['train_acc'], 'b-o', label='Train Acc')
    plt.plot(epochs, history['val_acc'], 'r-o', label='Val Acc')
    plt.title('Accuracy per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("results/training_metrics.png", dpi=300, bbox_inches='tight')
    print("Training metrics plot saved as 'results/training_metrics.png'")
    plt.show()

# Generate training plots
plot_training(history)
Expected Training Results:
- Training Accuracy: ~85-90% after 20 epochs
- Validation Accuracy: ~80-85% (indicates good generalization)
- Training Loss: Steadily decreasing trend
- Validation Loss: Should stabilize without significant overfitting
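One way to check the last two points on a real run is to look for the epoch where validation loss bottoms out and at the gap between training and validation accuracy. The sketch below assumes a history dict shaped like the one returned by train; the helper names are hypothetical and the numbers are invented for illustration, not measured results.

```python
def best_epoch(history):
    """1-based epoch with the lowest validation loss."""
    losses = history["val_loss"]
    return losses.index(min(losses)) + 1


def accuracy_gap(history):
    """Train minus validation accuracy at the final epoch."""
    return history["train_acc"][-1] - history["val_acc"][-1]


# Invented numbers, for illustration only.
toy_history = {
    "val_loss": [1.40, 1.10, 0.95, 0.97, 1.02],
    "train_acc": [0.50, 0.62, 0.70, 0.76, 0.81],
    "val_acc": [0.52, 0.61, 0.67, 0.68, 0.67],
}
print(best_epoch(toy_history))              # 3
print(round(accuracy_gap(toy_history), 2))  # 0.14
```

In this toy run, validation loss rises after epoch 3 while training accuracy keeps climbing, which is the usual signature of overfitting.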
Model Deployment
The final step involves saving the trained model and creating a deployment-ready system:
# Complete main execution pipeline
if __name__ == "__main__":
    import os
    import torch
    import torch.optim as optim
    from torch.utils.tensorboard import SummaryWriter

    # Configuration
    batch_size = 64
    num_workers = 2
    num_epochs = 20
    root_dir = './data'
    # Output directories must exist before anything is written to them
    os.makedirs("results", exist_ok=True)
    os.makedirs("models", exist_ok=True)
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Load data with augmentation
    train_loader, val_loader, test_loader = data_augmentation(
        root_dir, batch_size=batch_size, num_workers=num_workers
    )
    # Initialize model, loss, and optimizer
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Setup TensorBoard
    writer = SummaryWriter(log_dir="runs/cifar10_experiment")
    # Train the model
    history = train(model, train_loader, val_loader, criterion,
                    optimizer, device, num_epochs, writer,
                    excel_path="results/training_metrics.xlsx")
    writer.close()
    # Generate analytics
    report_df, cm_df = generate_metrices(history)
    plot_training(history)
    # Save complete model checkpoint (weights, optimizer state, and run metadata)
    model_path = "models/cifar10_cnn.pth"
    checkpoint = {
        'epoch': history['epoch'],
        'train_acc': history['train_acc'],
        'val_acc': history['val_acc'],
        'train_loss': history['train_loss'],
        'val_loss': history['val_loss'],
        'num_epochs': num_epochs,
        'model_name': 'SimpleCNN',
        'batch_size': batch_size,
        'num_workers': num_workers,
        'loss_function': 'CrossEntropyLoss',
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, model_path)
    print(f"Model saved to: {model_path}")
Tkinter Desktop Application
The project includes a user-friendly desktop application built with Tkinter that allows users to:
- Load and preview images for classification
- Run inference with the trained CNN model
- Display predictions with confidence scores
- Browse through CIFAR-10 test images
- Visualize model architecture and performance metrics
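The confidence scores mentioned above are typically obtained by applying softmax to the 10 logits the model outputs. A minimal plain-Python sketch follows; how the actual application computes them is an assumption here, and `top_k` and the example logits are hypothetical. In PyTorch the same step would normally use torch.softmax(outputs, dim=1).

```python
import math

CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (class name, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits standing in for one model output row.
logits = [0.1, -1.2, 0.3, 2.5, 0.0, 1.9, -0.5, 0.2, -2.0, -0.7]
for name, prob in top_k(logits):
    print(f"{name}: {prob:.1%}")
```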
Conclusion
This CIFAR-10 CNN classifier project demonstrates a complete deep learning workflow using PyTorch. The implementation showcases modern best practices in computer vision, from data augmentation to model deployment.
Key Learning Outcomes:
- CNN Architecture: Understanding convolutional layers, pooling, and fully connected layers
- Data Augmentation: Improving model generalization through diverse training data
- PyTorch Framework: Mastering tensors, autograd, and model training loops
- Monitoring & Visualization: Using TensorBoard for real-time training analysis
- Model Deployment: Creating user-friendly applications for model inference
- Performance Analysis: Interpreting training curves and classification metrics
Advanced Extensions:
This foundational project can be extended in numerous ways:
- Architecture Improvements: ResNet, DenseNet, or Vision Transformers
- Advanced Augmentation: Mixup, CutMix, or AutoAugment techniques
- Transfer Learning: Using pre-trained models for better performance
- Deployment Options: Web APIs, mobile apps, or cloud services
- Real-time Applications: Webcam-based live classification systems