
ResNet to the Rescue¶

Faisal Qureshi
Professor
Faculty of Science
Ontario Tech University
Oshawa ON Canada
http://vclab.science.ontariotechu.ca

Copyright information¶

© Faisal Qureshi

License¶

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

Tasks¶

The goal of this work is to use a pre-trained ResNet model to perform "custom computer vision tasks."

Specifically, we will use a ResNet pretrained on ImageNet to perform classification on the CIFAR10 dataset. Recall that the ImageNet model has 1000 classes, whereas CIFAR10 has only 10. This means that we cannot use the ResNet model out of the box. We will replace the classification head in the pretrained model, and we will then have to retrain the new classification head.

  1. Download a pretrained ResNet model from timm.
  2. Replace the classifier head, i.e., the "fc" layer, with our own.
  3. Set up PyTorch Lightning to train the model. We will only train the "fc" layer.

Go to the end of this notebook to see what you need to submit.

Get a pretrained ResNet model¶

Install timm. Check https://timm.fast.ai/ for more information about timm, a deep learning library that contains SOTA computer vision models.

In [58]:
!pip install --quiet timm
!pip install --quiet jupyterlab-widgets
!pip install --quiet ipywidgets
In [59]:
import timm
import torch

Models available in timm¶

timm stands for PyTorch Image Models; it contains a large collection of computer vision deep learning models.

In [60]:
# Set the following to True to list only
# the models that ship with pretrained weights
pretrained = True

avail_models = timm.list_models(pretrained=pretrained)
len(avail_models), avail_models[:10]
Out[60]:
(1298,
 ['bat_resnext26ts.ch_in1k',
  'beit_base_patch16_224.in22k_ft_in22k',
  'beit_base_patch16_224.in22k_ft_in22k_in1k',
  'beit_base_patch16_384.in22k_ft_in22k_in1k',
  'beit_large_patch16_224.in22k_ft_in22k',
  'beit_large_patch16_224.in22k_ft_in22k_in1k',
  'beit_large_patch16_384.in22k_ft_in22k_in1k',
  'beit_large_patch16_512.in22k_ft_in22k_in1k',
  'beitv2_base_patch16_224.in1k_ft_in1k',
  'beitv2_base_patch16_224.in1k_ft_in22k'])

Searching models¶

You can also search for models by name pattern as follows:

In [61]:
timm.list_models('resnet*', pretrained=pretrained)[:10]
Out[61]:
['resnet10t.c3_in1k',
 'resnet14t.c3_in1k',
 'resnet18.a1_in1k',
 'resnet18.a2_in1k',
 'resnet18.a3_in1k',
 'resnet18.fb_ssl_yfcc100m_ft_in1k',
 'resnet18.fb_swsl_ig1b_ft_in1k',
 'resnet18.gluon_in1k',
 'resnet18.tv_in1k',
 'resnet18d.ra2_in1k']

Getting pretrained ResNet¶

In [62]:
resnet = timm.create_model('resnet34', pretrained=True)
In [63]:
print('Model information:')

# Uncomment the following to print out the model.
# Recall that it is simply a PyTorch model.
#resnet

# Or, better, access the config as follows
resnet.default_cfg
Model information:
Out[63]:
{'url': 'https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a1_0-46f8f793.pth',
 'hf_hub_id': 'timm/resnet34.a1_in1k',
 'architecture': 'resnet34',
 'tag': 'a1_in1k',
 'custom_load': False,
 'input_size': (3, 224, 224),
 'test_input_size': (3, 288, 288),
 'fixed_input_size': False,
 'interpolation': 'bicubic',
 'crop_pct': 0.95,
 'test_crop_pct': 1.0,
 'crop_mode': 'center',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'num_classes': 1000,
 'pool_size': (7, 7),
 'first_conv': 'conv1',
 'classifier': 'fc',
 'origin_url': 'https://github.com/huggingface/pytorch-image-models',
 'paper_ids': 'arXiv:2110.00476'}
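
The mean, std, and input_size entries tell you how inputs should be preprocessed to match the pretrained weights. timm can build the corresponding preprocessing pipeline for you; a minimal sketch using timm's data utilities:

from timm.data import resolve_data_config, create_transform

# Build a torchvision-style preprocessing pipeline (resize, crop, ToTensor,
# normalize) from the model's default_cfg shown above.
config = resolve_data_config({}, model=resnet)
transform = create_transform(**config)
print(transform)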

Using ResNet for our task¶

This model was trained on ImageNet, which has 1000 classes.

In [64]:
resnet.get_classifier()
Out[64]:
Linear(in_features=512, out_features=1000, bias=True)

Swapping out the classifier¶

We want to use this model for the task of classifying CIFAR10 images. CIFAR10 has only 10 classes, so we cannot use this model out of the box.

We will replace the classification head (the fully-connected classifier layer with 1000 outputs).

In [65]:
resnet_cifar = timm.create_model('resnet34', pretrained=True, num_classes=10)
resnet_cifar.get_classifier()
Out[65]:
Linear(in_features=512, out_features=10, bias=True)
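
Equivalently, the head can be rebuilt on an already-created model. A minimal sketch of two interchangeable options, applied here to the 1000-class resnet from above:

import torch.nn as nn

# Option 1: ask timm to rebuild the classifier head in place.
resnet.reset_classifier(num_classes=10)

# Option 2: swap the layer by hand; default_cfg told us the classifier is 'fc'.
resnet.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=10, bias=True)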

Let's inspect the model.

In [66]:
resnet_cifar
Out[66]:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
  )
  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

Training the classifier¶

Since we have replaced the classification head of the pretrained model with our own (randomly initialized) classification layer, we need to train the new head before the model is useful for CIFAR10.
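
The LightningModule defined below freezes the backbone and leaves only the new head trainable. As a standalone sketch, the same effect looks like this:

# Freeze every parameter, then re-enable gradients for the new head only.
for param in resnet_cifar.parameters():
    param.requires_grad = False
for param in resnet_cifar.get_classifier().parameters():
    param.requires_grad = True

# Sanity check: only the 512x10 fc layer (plus biases) should remain trainable.
trainable = sum(p.numel() for p in resnet_cifar.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")  # 5130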

PyTorch Lightning: A Framework for Model Training¶

Check out lightning.ai.

In [67]:
!pip install --quiet lightning
!pip install --quiet seaborn
!pip install --quiet tabulate
In [68]:
import lightning as L
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import transforms

import matplotlib
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import numpy as np
import seaborn as sns
import tabulate
from IPython.display import HTML, display
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from PIL import Image

import os
In [69]:
L.seed_everything(42)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
Seed set to 42

CIFAR10 dataset¶

In [70]:
DATASET_PATH = './data'
CHECKPOINT_PATH = 'saved_models'

Computing mean and standard deviation¶

In [71]:
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

data_mean = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
data_std = (train_dataset.data / 255.0).std(axis=(0, 1, 2))
print("Data mean", data_mean)
print("Data std", data_std)
Files already downloaded and verified
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]

Transformations¶

  • Data augmentation for training data
In [72]:
# For training, we add some augmentation. Without it, a network this large would quickly overfit.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_std),
    ]
)

# No data augmentation for testing
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(data_mean, data_std)])

Datasets: train, validation, and test¶

In [73]:
train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)

# Note that the validation dataset doesn't use the augmentations applied to the training dataset
val_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)

# Re-seed before each split so both random_split calls produce the *same*
# 45000/5000 index partition; otherwise the validation set would overlap
# the training set.
L.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
L.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

test_set = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified

Dataloaders¶

Feeding data to our model(s)

In [74]:
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)

Displaying some images¶

In [75]:
NUM_IMAGES = 4
images = [train_dataset[idx][0] for idx in range(NUM_IMAGES)]
orig_images = [Image.fromarray(train_dataset.data[idx]) for idx in range(NUM_IMAGES)]
orig_images = [test_transform(img) for img in orig_images]

img_grid = torchvision.utils.make_grid(torch.stack(images + orig_images, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.title("Augmentation examples on CIFAR10")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
[Figure: "Augmentation examples on CIFAR10" — four augmented training images shown alongside the corresponding four original images]

Constructing a LightningModule¶

In [76]:
class CIFARModule(L.LightningModule):
    def __init__(self, model, model_hparams, optimizer_hparams):
        """
        model: PyTorch model that you plan to train
        model_hparams: Model hyperparameters, e.g., dropout, activation functions, etc.
                       Not used in this example.
        optimizer_hparams: Optimizer hyperparameters, e.g., learning rate, etc.
        """
        super().__init__()

        # Exports the hyperparameters to a YAML file and creates the "self.hparams" namespace
        self.save_hyperparameters()  
        self.model = model
        for param in self.model.parameters():
            param.requires_grad = False
        for param in self.model.get_classifier().parameters():
            param.requires_grad = True
        
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in Tensorboard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32) 

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        #
        # IMPORTANT
        #
        # Note that we are only passing the classifier's parameters to the optimizer,
        # since we do not need to update the weights of the model backbone
        #
        optimizer = optim.AdamW(self.model.get_classifier().parameters(), **self.hparams.optimizer_hparams)

        # Reduce the learning rate by a factor of 10 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        self.log("train_loss", loss)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches), and returns it afterwards
        self.log("test_acc", acc)

Training method¶

This method includes the training, validation, and test logic.

In [77]:
def train_model(model, save_name='resnet', **kwargs):

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = L.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),  # Where to save models
        # We run on a single GPU (if possible)
        accelerator="auto",
        devices=1,
        # How many epochs to train for (kept deliberately small here;
        # see the tasks below about making this an argument)
        max_epochs=2,
        callbacks=[
            ModelCheckpoint(
                save_weights_only=True, mode="max", monitor="val_acc"
            ),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
            LearningRateMonitor("epoch"),
        ],  # Log learning rate every epoch
    )  # In case your notebook crashes due to the progress bar, consider increasing the refresh rate
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        lightning_model = CIFARModule.load_from_checkpoint(pretrained_filename)
    else:
        print("No checkpoint found")
        L.seed_everything(42)  # To be reproducible
        lightning_model = CIFARModule(model, **kwargs)
        trainer.fit(lightning_model, train_loader, val_loader)
        lightning_model = CIFARModule.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(lightning_model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}  # test_step logs "test_acc" for both runs

    return model, result

Now train¶

In [ ]:
model_hparams = {}
optimizer_hparams = {"lr": 1e-3, "weight_decay": 1e-4}

train_model(resnet_cifar, model_hparams=model_hparams, optimizer_hparams=optimizer_hparams)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Seed set to 42

  | Name        | Type             | Params | In sizes       | Out sizes
------------------------------------------------------------------------------
0 | model       | ResNet           | 21.3 M | [1, 3, 32, 32] | [1, 10]  
1 | loss_module | CrossEntropyLoss | 0      | ?              | ?        
------------------------------------------------------------------------------
5.1 K     Trainable params
21.3 M    Non-trainable params
21.3 M    Total params
85.159    Total estimated model params size (MB)
No checkpoint found
Epoch 0: 100%|███████████████████████| 351/351 [09:07<00:00,  0.64it/s, v_num=4]
Epoch 1: 100%|███████████████████████| 351/351 [10:22<00:00,  0.56it/s, v_num=4]
Epoch 2: 100%|█████████████████████| 351/351 [1:11:07<00:00,  0.08it/s, v_num=4]
Epoch 3:  28%|██████▋                 | 97/351 [12:43<33:19,  0.13it/s, v_num=4]
[per-batch validation progress bars trimmed]

Complete the following tasks¶

1. Modify the train_model function to take the number of epochs as an argument.

2. Train the model for at least 20 epochs and plot the validation and test accuracies.

Look at saved_models/resnet/lightning_logs for this information; a sketch for reading these logs programmatically follows.
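
A minimal sketch for pulling the logged scalars out of the TensorBoard event files (EventAccumulator ships with the tensorboard package; the version_0 directory is an assumption — substitute your run):

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Load the event file Lightning wrote for this run and list the logged val_acc values.
events = EventAccumulator("saved_models/resnet/lightning_logs/version_0")
events.Reload()
val_acc = [(e.step, e.value) for e in events.Scalars("val_acc")]
print(val_acc)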

3. Modify the train_model function to load the model from a checkpoint.

4. Write a function that uses the model trained in step 2 to perform inference on a list of images. The function prints each image file alongside its true and predicted labels. We can use this function as follows (a minimal starting sketch appears after the expected output below):

image_files = ["1.png", "2.png"]
classify(image_files)

The output of this program will look similar to

Classification results:
1.png dog dog
2.png cat dog
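
A possible starting point for classify, showing only the core inference pattern (the class names below follow CIFAR10's standard order; how you obtain the true labels is up to you):

cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

def classify(image_files):
    resnet_cifar.eval()  # use running BatchNorm stats, no dropout
    print("Classification results:")
    with torch.no_grad():
        for f in image_files:
            # Apply the same normalization used for the test set.
            img = test_transform(Image.open(f).convert("RGB")).unsqueeze(0)
            pred = resnet_cifar(img).argmax(dim=-1).item()
            print(f, cifar10_classes[pred])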
 

5. (Bonus) The current classification head consists of a single fully-connected layer. Let's replace it with a neural network with one hidden layer. This can be achieved by replacing the "fc" layer with a sequence of layers, as seen in the following code snippet.

model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=1024, bias=True),
    nn.ReLU(),
    nn.BatchNorm1d(1024),
    nn.Dropout(0.5),
    nn.Linear(in_features=1024, out_features=10, bias=True)
)

where num_in_features is simply the feature size of the ResNet encoder. You can find this information as follows:

num_in_features = resnet_cifar.get_classifier().in_features
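
After performing the swap on resnet_cifar, a one-line forward pass is a cheap sanity check that the new head produces 10 logits:

out = resnet_cifar(torch.randn(2, 3, 32, 32))  # batch of 2: BatchNorm1d needs >1 sample in train mode
print(out.shape)  # expect torch.Size([2, 10])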

Jupyter notebook¶

Source notebook is available here.

GPU resources¶

(Experimental) You can use the GPU resources available at https://hubdev.science.ontariotechu.ca/ to complete your lab.
