# Source code for optexp.datasets.image.mnist

from functools import lru_cache
from typing import Optional

import torch
import torchvision
from attrs import frozen
from torch.types import Device
from torch.utils.data import TensorDataset

from optexp.config import Config
from optexp.datasets.dataset import (
    Dataset,
    Downloadable,
    HasClassCounts,
    InMemory,
    Split,
)
from optexp.datasets.utils import make_dataloader

MEAN, STD = 0.1307, 0.3081


@frozen
class MNIST(Dataset, HasClassCounts, Downloadable, InMemory):
    """The `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset, provided through
    `TorchVision <https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html>`_.

    The ``"tr"`` split maps to TorchVision's ``train=True`` data; every other
    split maps to ``train=False``, which this class exposes as the validation
    split ``"va"``. No separate test set is provided (see :meth:`has_test_set`).
    """

    def get_num_samples(self, split: Split) -> int:
        """Return the number of samples in ``split``.

        Raises:
            ValueError: if ``split`` is neither ``"tr"`` nor ``"va"``.
        """
        if split == "tr":
            return 60_000
        if split == "va":
            return 10_000
        raise ValueError(f"Invalid tr_va: {split}")

    # NOTE(review): ``lru_cache`` on a method keys the cache on ``self`` and
    # keeps a reference to every instance it has been called on.
    @lru_cache()
    def class_counts(self, split: Split) -> torch.Tensor:
        """Return how often each label value occurs in the targets of ``split``."""
        return torch.bincount(self._get_dataset(split).targets)

    def data_input_shape(self, batch_size) -> torch.Size:
        """Return the shape of one input batch: ``[batch_size, 1, 28, 28]``."""
        return torch.Size([batch_size, 1, 28, 28])

    def model_output_shape(self, batch_size) -> torch.Size:
        """Return the shape of one batch of model outputs: ``[batch_size, 10]``."""
        return torch.Size([batch_size, 10])

    def has_test_set(self) -> bool:
        """Return ``False``: only ``"tr"`` and ``"va"`` splits are exposed."""
        return False

    def is_downloaded(self) -> bool:
        """Return ``True`` if all four raw MNIST files are present on disk."""
        # ``download`` and ``_get_dataset`` pass the dataset directory as
        # ``root`` to ``torchvision.datasets.MNIST``, which stores its raw
        # files under ``<root>/MNIST/raw`` (folder named after the TorchVision
        # class). The check must look in that same folder.
        raw_dir = Config.get_dataset_directory() / "MNIST" / "raw"
        return all(
            (raw_dir / file).exists()
            for file in [
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte",
            ]
        )

    def download(self) -> None:
        """Download both the ``train=True`` and ``train=False`` parts via TorchVision."""
        path = str(Config.get_dataset_directory())
        for train in [True, False]:
            torchvision.datasets.MNIST(path, download=True, train=train)

    @staticmethod
    def _get_dataset(split: Split):
        """Build the TorchVision dataset for ``split``.

        Samples are converted with ``ToTensor`` and normalized with the
        module-level ``MEAN`` and ``STD``. Any split other than ``"tr"``
        selects ``train=False``.
        """
        return torchvision.datasets.MNIST(
            root=str(Config.get_dataset_directory()),
            train=split == "tr",
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=MEAN, std=STD),
                ]
            ),
        )

    def _get_tensor_dataset(self, split: Split, to_device: Optional[Device] = None):
        """Return ``split`` as a ``TensorDataset`` of normalized inputs and targets.

        Args:
            split: which split to load.
            to_device: if not ``None``, both tensors are moved to this device.
        """
        dataset = self._get_dataset(split)
        data, targets = dataset.data, dataset.targets
        # ``dataset.data`` bypasses the per-sample transform pipeline, so the
        # scaling and normalization are applied here by hand.
        # Assumes the raw data holds pixel values in [0, 255] -- TODO confirm.
        data = torchvision.transforms.Normalize(mean=MEAN, std=STD)(data / 255.0)
        # Insert a channel axis at position 1; presumably (N, 28, 28) becomes
        # (N, 1, 28, 28), matching ``data_input_shape``.
        data = data.unsqueeze(1)
        if to_device is not None:
            data, targets = data.to(to_device), targets.to(to_device)
        return TensorDataset(data, targets)

    def get_dataloader(
        self, b: int, split: Split, num_workers: int
    ) -> torch.utils.data.DataLoader:
        """Return a dataloader over the TorchVision dataset for ``split``.

        ``b`` and ``num_workers`` are forwarded to ``make_dataloader``.
        """
        return make_dataloader(self._get_dataset(split), b, num_workers)

    def get_in_memory_dataloader(
        self,
        b: int,
        split: Split,
        num_workers: int,
        to_device: Optional[Device] = None,
    ) -> torch.utils.data.DataLoader:
        """Return a dataloader over the in-memory tensor version of ``split``.

        The tensors come from ``_get_tensor_dataset`` and are moved to
        ``to_device`` when it is given. ``b`` and ``num_workers`` are
        forwarded to ``make_dataloader``.
        """
        return make_dataloader(
            self._get_tensor_dataset(split, to_device), b, num_workers
        )