ResNet to the Rescue¶
Faisal Qureshi
Professor
Faculty of Science
Ontario Tech University
Oshawa ON Canada
http://vclab.science.ontariotechu.ca
Copyright information¶
© Faisal Qureshi
License¶
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
Tasks¶
The goal of this work is to use a pre-trained ResNet model to perform "custom computer vision tasks."
Specifically, we will use a ResNet pretrained on ImageNet to perform classification on the CIFAR10 dataset. Recall that the ImageNet model has 1000 classes, whereas CIFAR10 has only 10. This means that we cannot use the ResNet model out of the box: we will replace the classification head in the pretrained model and then retrain it.
- Download a pretrained ResNet model from timm.
- Replace the classifier head, i.e., the "fc" layer, with our own.
- Set up PyTorch Lightning to train the model. We will only train the "fc" layer.
Go to the end of this notebook to see what you need to submit.
Get a pretrained ResNet model¶
Install timm. Check https://timm.fast.ai/ for information about timm, a deep learning library that contains SOTA computer vision models.
!pip install --quiet timm
!pip install --quiet jupyterlab-widgets
!pip install --quiet ipywidgets
import timm
import torch
Models available in timm¶
timm stands for PyTorch Image Models; it contains a large collection of computer vision deep learning models.
# Set the following to True to list only models that have pretrained weights
pretrained = True
avail_models = timm.list_models(pretrained=pretrained)
len(avail_models), avail_models[:10]
(1298, ['bat_resnext26ts.ch_in1k', 'beit_base_patch16_224.in22k_ft_in22k', 'beit_base_patch16_224.in22k_ft_in22k_in1k', 'beit_base_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_224.in22k_ft_in22k', 'beit_large_patch16_224.in22k_ft_in22k_in1k', 'beit_large_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_512.in22k_ft_in22k_in1k', 'beitv2_base_patch16_224.in1k_ft_in1k', 'beitv2_base_patch16_224.in1k_ft_in22k'])
Searching models¶
You can also search for models as follows:
timm.list_models('resnet*', pretrained=pretrained)[:10]
['resnet10t.c3_in1k', 'resnet14t.c3_in1k', 'resnet18.a1_in1k', 'resnet18.a2_in1k', 'resnet18.a3_in1k', 'resnet18.fb_ssl_yfcc100m_ft_in1k', 'resnet18.fb_swsl_ig1b_ft_in1k', 'resnet18.gluon_in1k', 'resnet18.tv_in1k', 'resnet18d.ra2_in1k']
Getting pretrained ResNet¶
resnet = timm.create_model('resnet34', pretrained=True)
print('Model information:')
# Uncomment the following to print out the model.
# Recall that it is simply a PyTorch model.
#resnet
# Or, better, access the config as follows
resnet.default_cfg
Model information:
{'url': 'https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a1_0-46f8f793.pth', 'hf_hub_id': 'timm/resnet34.a1_in1k', 'architecture': 'resnet34', 'tag': 'a1_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'}
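The default_cfg also records how inputs were preprocessed during pretraining (input size, interpolation, mean, std). As an aside, timm can build a matching preprocessing pipeline directly from this config. A minimal sketch, assuming a timm version that provides resolve_data_config and create_transform (we do not use this pipeline below, since we will compute CIFAR10 statistics ourselves):

from timm.data import resolve_data_config, create_transform

# Build the eval-time transform that matches the model's pretraining setup
# (resize, center crop, bicubic interpolation, ImageNet mean/std).
config = resolve_data_config({}, model=resnet)
transform = create_transform(**config)
print(transform)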
Using ResNet for our task¶
This model was trained on ImageNet, which has 1000 classes.
resnet.get_classifier()
Linear(in_features=512, out_features=1000, bias=True)
Swapping out the classifier¶
We want to use this model to classify CIFAR10 images. CIFAR10 has only 10 classes, so we cannot use this model out of the box.
We will replace the classification head (the fully-connected classifier layer with 1000 outputs). Passing num_classes=10 to timm.create_model does this for us:
resnet_cifar = timm.create_model('resnet34', pretrained=True, num_classes=10)
resnet_cifar.get_classifier()
Linear(in_features=512, out_features=10, bias=True)
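Passing num_classes at creation time is the simplest route. If you already have a model instance, the head can also be swapped after the fact. A minimal sketch of two alternatives (both should be equivalent for ResNets, whose classifier attribute is named "fc"; the resnet_alt names are just for illustration):

import torch.nn as nn

# Option 1: let timm rebuild the classifier head in place
resnet_alt = timm.create_model('resnet34', pretrained=True)
resnet_alt.reset_classifier(num_classes=10)

# Option 2: replace the "fc" layer manually with a fresh linear layer
resnet_alt2 = timm.create_model('resnet34', pretrained=True)
resnet_alt2.fc = nn.Linear(resnet_alt2.fc.in_features, 10)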
Let's inspect the model
resnet_cifar
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (2): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (2): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (3): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): 
ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (2): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (3): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (4): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (5): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() 
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) (2): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (drop_block): Identity() (act1): ReLU(inplace=True) (aa): Identity() (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) ) ) (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1)) (fc): Linear(in_features=512, out_features=10, bias=True) )
Training the classifier¶
Since we have replaced the classification head in the pretrained model with our own, randomly initialized, layer, we need to train it.
PyTorch Lightning: A Framework for Model Training¶
Check out lightning.ai.
!pip install --quiet lightning
!pip install --quiet seaborn
!pip install --quiet tabulate
import lightning as L
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import transforms
import matplotlib
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import numpy as np
import seaborn as sns
import tabulate
from IPython.display import HTML, display
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from PIL import Image
import os
L.seed_everything(42)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
Seed set to 42
CIFAR10 dataset¶
DATASET_PATH = './data'
CHECKPOINT_PATH = 'saved_models'
Computing mean and standard deviation¶
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
data_mean = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
data_std = (train_dataset.data / 255.0).std(axis=(0, 1, 2))
print("Data mean", data_mean)
print("Data std", data_std)
Files already downloaded and verified Data mean [0.49139968 0.48215841 0.44653091] Data std [0.24703223 0.24348513 0.26158784]
Transformations¶
- Data augmentation for training data
# For training, we add some augmentation; without it, a network this large would quickly overfit.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_std),
    ]
)
# No data augmentation for testing
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(data_mean, data_std)])
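Note that the pretrained config above lists input_size (3, 224, 224), while CIFAR10 images are 32x32. The ResNet still accepts 32x32 inputs because of its global average pooling, but the pretrained features were learned at a higher resolution. If you want to experiment, one option is to upsample the images to the pretraining resolution; a sketch of such transforms follows (not used in the rest of this notebook, and considerably slower to train with):

# Optional alternative: resize CIFAR10 images to the resolution the backbone was pretrained on.
train_transform_224 = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_std),
    ]
)
test_transform_224 = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(data_mean, data_std)]
)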
Datasets: train, validation, and test¶
train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
# Note that the validation dataset doesn't use the augmentations applied to the training dataset
val_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
# Re-seed before each split so that both calls produce the same 45000/5000 partition;
# otherwise the validation set would overlap the training set.
L.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
L.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])
test_set = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)
Files already downloaded and verified Files already downloaded and verified Files already downloaded and verified
Dataloaders¶
Feeding data to our model(s)
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
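As an optional sanity check, we can pull one batch from the training loader and confirm that the tensor shapes match what the model expects:

# Grab one batch and check shapes.
imgs, labels = next(iter(train_loader))
print(imgs.shape, labels.shape)  # expected: torch.Size([128, 3, 32, 32]) torch.Size([128])
with torch.no_grad():
    print(resnet_cifar(imgs).shape)  # expected: torch.Size([128, 10])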
Displaying some images¶
NUM_IMAGES = 4
images = [train_dataset[idx][0] for idx in range(NUM_IMAGES)]
orig_images = [Image.fromarray(train_dataset.data[idx]) for idx in range(NUM_IMAGES)]
orig_images = [test_transform(img) for img in orig_images]
img_grid = torchvision.utils.make_grid(torch.stack(images + orig_images, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)
plt.figure(figsize=(8, 8))
plt.title("Augmentation examples on CIFAR10")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
Constructing a LightningModule¶
class CIFARModule(L.LightningModule):
    def __init__(self, model, model_hparams, optimizer_hparams):
        """
        model: PyTorch model that you plan to train
        model_hparams: Model hyperparameters, e.g., dropout, activation functions, etc.
            Not used in this example.
        optimizer_hparams: Optimizer hyperparameters, e.g., learning rate, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file, and creates the "self.hparams" namespace
        self.save_hyperparameters()
        self.model = model
        # Freeze the backbone; only the classifier head will be trained
        for param in self.model.parameters():
            param.requires_grad = False
        for param in self.model.get_classifier().parameters():
            param.requires_grad = True
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in Tensorboard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        #
        # IMPORTANT
        #
        # Note that we are only passing the classifier's parameters to the optimizer,
        # since we do not need to update the weights of the model backbone
        #
        optimizer = optim.AdamW(self.model.get_classifier().parameters(), **self.hparams.optimizer_hparams)
        # We will reduce the learning rate by 0.1 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        self.log("train_loss", loss)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches)
        self.log("val_acc", acc)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches), and returns it afterwards
        self.log("test_acc", acc)
Training method¶
This method contains the training, validation, and test logic.
def train_model(model, save_name='resnet', **kwargs):
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = L.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),  # Where to save models
        # We run on a single GPU (if possible)
        accelerator="auto",
        devices=1,
        # How many epochs to train for if no patience is set
        max_epochs=2,
        callbacks=[
            ModelCheckpoint(
                save_weights_only=True, mode="max", monitor="val_acc"
            ),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
            LearningRateMonitor("epoch"),
        ],  # Log learning rate every epoch
    )  # In case your notebook crashes due to the progress bar, consider increasing the refresh rate
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        lightning_model = CIFARModule.load_from_checkpoint(pretrained_filename)
    else:
        print("No checkpoint found")
        L.seed_everything(42)  # To be reproducible
        lightning_model = CIFARModule(model, **kwargs)
        trainer.fit(lightning_model, train_loader, val_loader)
        lightning_model = CIFARModule.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(lightning_model, dataloaders=val_loader, verbose=False)
    test_result = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result
Now train¶
model_hparams = {}
optimizer_hparams = {"lr": 1e-3, "weight_decay": 1e-4}
train_model(resnet_cifar, model_hparams=model_hparams, optimizer_hparams=optimizer_hparams)
GPU available: True (mps), used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs Seed set to 42 | Name | Type | Params | In sizes | Out sizes ------------------------------------------------------------------------------ 0 | model | ResNet | 21.3 M | [1, 3, 32, 32] | [1, 10] 1 | loss_module | CrossEntropyLoss | 0 | ? | ? ------------------------------------------------------------------------------ 5.1 K Trainable params 21.3 M Non-trainable params 21.3 M Total params 85.159 Total estimated model params size (MB)
No checkpoint found
[Training progress output truncated: epochs 0-2 each ran 351 training batches followed by a 40-batch validation pass; epoch 3 was interrupted partway through.]
Complete the following tasks¶
1. Modify the train_model function to take the number of epochs as an argument.
2. Train the model for at least 20 epochs and plot the validation and test accuracies. Look at saved_models/resnet/lightning_logs for this information (a sketch of reading these logs appears after this list).
3. Modify the train_model function to load the model from a checkpoint.
4. Write a function that uses the model trained in step 2 to perform inference on a list of images. The function prints each image file name alongside its true and predicted labels. We can use this function as follows:
image_files = ["1.png", "2.png"]
classify(image_files)
The output of this program will look similar to:
Classification results:
1.png dog dog
2.png cat dog
5. (Bonus) The current classification head consists of a single fully-connected layer. Let's replace it with a neural network with one hidden layer. This can be achieved by replacing the "fc" layer with a sequence of layers, as in the following code snippet.
model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=1024, bias=True),
    nn.ReLU(),
    nn.BatchNorm1d(1024),
    nn.Dropout(0.5),
    nn.Linear(in_features=1024, out_features=10, bias=True)
)
where num_in_features is the feature size of the ResNet encoder. You can find this value as follows:
num_in_features = resnet_cifar.get_classifier().in_features
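For task 2, the validation and test accuracies logged by CIFARModule end up in TensorBoard event files under saved_models/resnet/lightning_logs. A minimal sketch of reading them back for plotting, assuming the default TensorBoard logger and that the tensorboard package is installed; "version_0" is just an example run directory you may need to adjust:

import os
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point at one run directory inside lightning_logs.
log_dir = os.path.join(CHECKPOINT_PATH, 'resnet', 'lightning_logs', 'version_0')
events = EventAccumulator(log_dir)
events.Reload()

# "val_acc" is the tag logged in validation_step above.
steps = [e.step for e in events.Scalars('val_acc')]
values = [e.value for e in events.Scalars('val_acc')]

plt.plot(steps, values, marker='o')
plt.xlabel('step')
plt.ylabel('val_acc')
plt.title('Validation accuracy')
plt.show()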
Jupyter notebook¶
Source notebook is available here.
GPU resources¶
(Experimental) You can use the GPU resources available at https://hubdev.science.ontariotechu.ca/ to complete your lab.