Advanced Topics
An Easy Way to Use ResNet

Although the code for ResNet-50 in the previous part works well, it is not very convenient to use, and you certainly don't want to rewrite the same code every time you need ResNet-50.

In this part, we will show you how to use ResNet in a more convenient way. Since ResNet is a very popular model, PyTorch provides pre-trained ResNet models in its torchvision library, and we can load one directly.
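
For example, loading ResNet-50 with its ImageNet weights and inspecting its classification head takes only a couple of lines. This is a minimal sketch; note that newer torchvision releases use the weights argument instead of pretrained:

from torchvision.models import resnet50

model = resnet50(pretrained=True)  # downloads the ImageNet weights on first use
print(model.fc)  # Linear(in_features=2048, out_features=1000, bias=True)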

Training ResNet-50 on CIFAR-10

First, let's import the necessary libraries.

import torch
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torchvision.models import resnet50
from torch import nn, optim

Then, we can transform the data and load the CIFAR-10 dataset.

class RandomResize(transforms.Resize):
    """Scale augmentation: resize the shorter image side to a random length in [min_size, max_size]."""
    def __init__(self, min_size=256, max_size=480, interpolation=transforms.InterpolationMode.BILINEAR):
        super().__init__(size=1)  # placeholder; the real size is drawn in forward()
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation

    def forward(self, img):
        # Draw a random target length for the shorter side
        size = torch.randint(self.min_size, self.max_size + 1, (1,)).item()
        # PIL images report size as (width, height)
        short, long = min(img.size), max(img.size)
        if short == img.size[0]:  # width is the shorter side
            ow, oh = size, int(size * long / short)
        else:                     # height is the shorter side
            ow, oh = int(size * long / short), size
        return F.resize(img, (oh, ow), self.interpolation)
 
mean = [0.4914, 0.4822, 0.4465]  # CIFAR-10 per-channel means
std = [1.0, 1.0, 1.0]            # std of 1.0 means mean subtraction only, no rescaling
 
transform = transforms.Compose([
    RandomResize(),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
 
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True)
 
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
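
Before training, it can be worth pulling a single batch from train_loader to confirm that the augmentation pipeline produces the expected tensor shapes. This quick check reuses only the loader defined above:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([256, 3, 224, 224])
print(labels.shape)  # torch.Size([256])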

Now, we can build the ResNet-50 model (without pre-trained weights, since we will train it from scratch) and adapt its first convolution, max-pooling layer, and final fully connected layer to the CIFAR-10 dataset.

model = resnet50(pretrained=False)  # random initialization; we train from scratch
# Use a smaller 3x3 stem and drop the initial max-pool to preserve spatial resolution
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()
# Replace the 1000-way ImageNet head with a 10-way head for CIFAR-10
model.fc = nn.Linear(model.fc.in_features, 10)
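
To confirm the modified network, we can run a small dummy batch through it; the 224x224 input size matches the RandomCrop above, and the output should have 10 logits per image:

x = torch.randn(2, 3, 224, 224)  # dummy batch of two images
print(model(x).shape)  # torch.Size([2, 10])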

Then, we can train the model as usual.

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
 
train_losses = []
val_losses = []
 
for epoch in range(100):
    model.train()
    train_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
 
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
 
        train_loss += loss.item()
 
    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
 
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
 
    avg_val_loss = val_loss / len(test_loader)
    val_losses.append(avg_val_loss)
 
    print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')
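
The loop above only tracks losses. If you also want the final test accuracy, a short evaluation pass over test_loader (reusing model and device from above) is enough:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)  # class with the highest logit
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f'Test accuracy: {100.0 * correct / total:.2f}%')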
 

Using a Pre-trained ResNet-50

Since the pre-trained ResNet-50 was trained on 224x224 ImageNet images, we need to resize the CIFAR-10 images and normalize them with the statistics that the pre-trained weights expect.

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    # ImageNet statistics, matching the pre-trained weights
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
 
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
 
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)

We can set the pretrained parameter to True to download the ImageNet weights (newer torchvision releases use the weights argument instead), replace the final fully connected layer with a 10-class head, and fine-tune the whole network on CIFAR-10.

model = resnet50(pretrained=True)  # download and load the ImageNet weights
model.fc = nn.Linear(model.fc.in_features, 10)  # new 10-class head for CIFAR-10
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model.train()  # put the network in training mode before fine-tuning
for epoch in range(50):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)  # move the batch to the same device as the model
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
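
Fine-tuning every layer is not the only option. If you prefer to keep the pre-trained features fixed and train only the new classification head, you can freeze the backbone before building the optimizer; the following is a sketch of that variant rather than a replacement for the code above:

model = resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # freeze the pre-trained backbone
model.fc = nn.Linear(model.fc.in_features, 10)  # the new head is trainable by default
model.to(device)

# Only the parameters of the new head are optimized
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)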