An Easy Way to Use ResNet
Although the code for ResNet-50 in the previous part works well, it is not very convenient to use, and you definitely don't want to rewrite the same code every time you need ResNet-50.
In this part, we will show you how to use ResNet in a more convenient way. Since ResNet is a very popular model, torchvision (PyTorch's companion library for computer vision) ships a ready-made ResNet implementation together with optional pre-trained weights. We can use it directly.
Training ResNet-50 on CIFAR-10
First, let's import the necessary libraries.
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torch import nn, optim
Then, we can transform the data and load the CIFAR-10 dataset.
class RandomResize(transforms.Resize):
    """Resize an image so that its shorter side gets a randomly drawn length.

    Every call draws an integer uniformly from ``[min_size, max_size]`` and
    rescales the image, preserving its aspect ratio, so that the shorter side
    equals that integer.

    Args:
        min_size: Smallest length of the shorter side (inclusive).
        max_size: Largest length of the shorter side (inclusive).
        interpolation: Interpolation mode forwarded to the resize call.

    Raises:
        ValueError: If ``min_size`` is greater than ``max_size``.
    """

    def __init__(self, min_size=256, max_size=480, interpolation=transforms.InterpolationMode.BILINEAR):
        if min_size > max_size:
            raise ValueError(
                f"min_size ({min_size}) must not exceed max_size ({max_size})"
            )
        # The base class insists on a size; it is only a placeholder because
        # forward() draws a fresh size on every call.
        super().__init__(size=min_size, interpolation=interpolation)
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation

    def forward(self, img):
        """Return ``img`` rescaled to a random shorter-side length."""
        # torch.randint excludes the upper bound, hence the +1.
        size = int(torch.randint(self.min_size, self.max_size + 1, (1,)).item())
        # Reach the functional API through the already-imported ``transforms``
        # package. Given an int, resize() matches the shorter edge to ``size``
        # and scales the longer edge by the same ratio.
        return transforms.functional.resize(img, size, self.interpolation)
# Per-channel mean passed to Normalize. With a std of 1.0 the transform only
# subtracts the mean and leaves the scale of the pixel values unchanged.
mean = [0.4914, 0.4822, 0.4465]
std = [1.0, 1.0, 1.0]

# Training-time augmentation: random scale, random 224x224 crop, random flip
# and colour jitter, followed by tensor conversion and normalization.
transform = transforms.Compose([
    RandomResize(),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Evaluation must be deterministic, otherwise the validation loss changes from
# run to run for the same weights. Use a fixed resize and a centre crop instead
# of the random augmentations above.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# ``datasets`` is accessed through the imported ``torchvision`` package.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
Now, we can create a ResNet-50 model without pre-trained weights (we will train it from scratch) and modify its first convolution, max-pooling layer, and last layer to fit the CIFAR-10 dataset.
# Start from a randomly initialised ResNet-50 (no pre-trained weights).
model = resnet50(pretrained=False)
# Swap the stem convolution for a 3x3, stride-1 convolution without bias.
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
# Turn the stem max-pooling into a no-op.
model.maxpool = nn.Identity()
# Replace the classifier head with one that outputs 10 classes.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10)
Then, we can train the model as usual.
# Loss function, optimizer and compute device for the from-scratch run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# One entry per epoch: the mean of the per-batch losses on each split.
train_losses = []
val_losses = []

for epoch in range(100):
    # ---- training pass ----
    model.train()
    epoch_train_total = 0.0
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(images), targets)
        batch_loss.backward()
        optimizer.step()

        epoch_train_total += batch_loss.item()

    avg_train_loss = epoch_train_total / len(train_loader)
    train_losses.append(avg_train_loss)

    # ---- validation pass (no gradients are tracked) ----
    model.eval()
    epoch_val_total = 0.0
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)
            epoch_val_total += criterion(model(images), targets).item()

    avg_val_loss = epoch_val_total / len(test_loader)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')
Use Pre-trained ResNet-50
We need to transform the data in CIFAR-10 to the format that the pre-trained ResNet-50 expects.
# Preprocessing for the ImageNet-pre-trained ResNet-50: upscale the images to
# 224 pixels and normalize with the ImageNet channel statistics that the
# torchvision pre-trained weights were trained with, not the CIFAR-10 ones.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training and test splits share the same deterministic preprocessing.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)
We can set the `pretrained` parameter to `True` to load the pre-trained ResNet-50 model.
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load ResNet-50 with pre-trained weights and replace its classifier head
# with a new 10-class layer.
model = resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Fine-tuning updates the weights, so the model has to be in training mode.
model.train()
for epoch in range(50):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        # Each batch must live on the same device as the model.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss of the last 2000 mini-batches, then reset.
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0