Source code for itwinai.torch.models.mnist

# --------------------------------------------------------------------------------------
# Part of the interTwin Project: https://www.intertwin.eu/
#
# Created by: Matteo Bunino
#
# Credit:
# - Matteo Bunino <matteo.bunino@cern.ch> - CERN
# --------------------------------------------------------------------------------------

import lightning as L
import torch
import torch.nn as nn
from torch.nn import functional as F


class MNISTModel(L.LightningModule):
    """Simple PyTorch Lightning model for MNIST.

    A multilayer perceptron over flattened ``1 x 28 x 28`` images. It has two
    hidden linear layers of width ``hidden_size``, each followed by ReLU and
    dropout (p=0.1). Its output is log-probabilities over 10 classes.

    Adapted from
    https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/mnist-hello-world.html

    Args:
        hidden_size (int): width of the two hidden linear layers.
            Defaults to 64.
    """

    def __init__(
        self,
        hidden_size: int = 64,
    ):
        super().__init__()

        # Automatically save constructor args as hyperparameters
        self.save_hyperparameters()

        # Set our init args as class attributes
        self.hidden_size = hidden_size

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, height, width = self.dims

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * height * width, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of inputs.

        Args:
            x (torch.Tensor): input batch. After flattening every dimension
                except the first, each sample must have 1*28*28 = 784
                features (e.g. shape ``(B, 1, 28, 28)``).

        Returns:
            torch.Tensor: log-probabilities of shape ``(B, num_classes)``.
        """
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def _compute_loss(self, batch) -> torch.Tensor:
        """Return the negative log-likelihood loss for an ``(x, y)`` batch.

        ``y`` must contain integer class indices, as required by
        ``F.nll_loss``.
        """
        x, y = batch
        # forward() already applies log_softmax, so nll_loss on its output
        # is equivalent to cross-entropy on the raw scores.
        log_probs = self(x)
        return F.nll_loss(log_probs, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(x, y)`` batch.

        Returns:
            torch.Tensor: the loss that Lightning backpropagates.
        """
        loss = self._compute_loss(batch)
        # Log metrics with autolog
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one ``(x, y)`` batch."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one ``(x, y)`` batch."""
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the predicted class index for every sample in ``batch``.

        Args:
            batch: either a bare input tensor, or a tuple/list whose first
                element is the input tensor (any further elements, such as
                targets, are ignored).
            batch_idx (int): index of the batch (unused).
            dataloader_idx (int): index of the dataloader (unused).

        Returns:
            torch.Tensor: integer class predictions of shape ``(B,)``.
        """
        # Prediction dataloaders often yield inputs only. Unpacking a bare
        # tensor with ``x, _ = batch`` would split it along the batch
        # dimension, so only index into real (inputs, ...) sequences.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        log_probs = self(x)
        return torch.argmax(log_probs, dim=1)