import wandb
import torch, torchvision
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as T
# Keep only the download mirrors whose URL does not contain the
# yann.lecun.com host, so torchvision never tries that one.
MNIST.mirrors = list(
    filter(lambda url: "http://yann.lecun.com/" not in url, MNIST.mirrors)
)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
def get_dataloader(is_train, batch_size, slice=5):
    """Build a DataLoader over an evenly strided subset of MNIST.

    Args:
        is_train: Passed to ``MNIST(train=...)`` to pick the split; batches
            are shuffled only when this is truthy.
        batch_size: Number of samples per batch.
        slice: Stride of the subsample; every ``slice``-th example, starting
            at index 0, is kept.

    Returns:
        A ``torch.utils.data.DataLoader`` over the subsampled split, using
        two worker processes and pinned memory.
    """
    # Downloads the data into the current directory if it is not there yet.
    mnist = MNIST(root=".", train=is_train, download=True, transform=T.ToTensor())

    kept_indices = range(0, len(mnist), slice)
    subset = torch.utils.data.Subset(mnist, indices=kept_indices)

    return torch.utils.data.DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=bool(is_train),
        num_workers=2,
        pin_memory=True,
    )
def get_model(dropout):
    """Create a small fully connected classifier and move it to ``device``.

    The network flattens its input, maps 28 * 28 features to 256 hidden
    units (batch norm, ReLU, dropout) and then to 10 outputs.

    Args:
        dropout: Drop probability handed to ``nn.Dropout``.

    Returns:
        The ``nn.Sequential`` model, placed on the module-level ``device``.
    """
    hidden_units = 256
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, hidden_units),
        nn.BatchNorm1d(hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_units, 10),
    ]
    network = nn.Sequential(*layers)
    return network.to(device)
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate ``model`` on ``valid_dl`` and optionally log one batch of predictions.

    The model is switched to eval mode and is not switched back; a caller
    that keeps training afterwards must call ``model.train()`` itself.

    Args:
        model: Module to evaluate; invoked as ``model(images)``.
        valid_dl: Iterable yielding ``(images, labels)`` batches.
        loss_func: Callable invoked as ``loss_func(outputs, labels)``. Its
            result is weighted by the batch size before averaging, so it is
            assumed to be a per-batch mean (confirm for custom losses).
        log_images: When true, the batch at position ``batch_idx`` is passed
            to ``log_image_table`` together with its predictions and softmax
            probabilities.
        batch_idx: Index of the batch to log; keeping it fixed logs the same
            batch on every call.

    Returns:
        A ``(mean_loss, accuracy)`` tuple, both averaged over the samples
        actually processed. ``accuracy`` is a Python float; ``mean_loss`` has
        whatever type ``loss_func`` produces (typically a 0-dim tensor).

    Raises:
        ZeroDivisionError: If ``valid_dl`` yields no samples.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    # Count the samples we really see instead of trusting len(valid_dl.dataset):
    # the two differ when the loader drops the last batch or uses a sampler.
    seen = 0
    with torch.inference_mode():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            n_batch = labels.size(0)
            seen += n_batch

            # Undo the per-batch averaging so batches of unequal size are
            # weighted correctly in the final mean.
            val_loss += loss_func(outputs, labels) * n_batch

            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / seen, correct / seen