Source code for optexp.datasets.dataset

from abc import ABC, abstractmethod
from typing import Literal, Optional

import torch
from attrs import frozen
from torch.types import Device
from torch.utils.data import DataLoader

from optexp.component import Component

# Type of a dataset split identifier. Presumably "tr" = training,
# "va" = validation and "te" = test -- confirm against the dataset classes.
Split = Literal["tr", "va", "te"]
# Every valid split identifier, as a list that callers can iterate over.
splits: list[Split] = ["tr", "va", "te"]


@frozen
class Dataset(ABC, Component):
    """Abstract base class for datasets.

    Concrete datasets implement the abstract methods below to expose their
    data through dataloaders and to describe the tensor shapes involved.
    """

    @abstractmethod
    def get_dataloader(self, b: int, split: Split, num_workers: int) -> DataLoader:
        """Return a dataloader with batch size ``b`` for the requested split.

        Args:
            b: Batch size of the returned dataloader.
            split: Identifier of the split to load (one of ``Split``).
            num_workers: Number of workers used for data loading.
        """
        raise NotImplementedError()

    @abstractmethod
    def data_input_shape(self, batch_size: int) -> torch.Size:
        """The expected input shape of the data, including the batch size."""
        raise NotImplementedError()

    @abstractmethod
    def model_output_shape(self, batch_size: int) -> torch.Size:
        """The expected output shape of the data, the shape of the targets."""
        raise NotImplementedError()

    @abstractmethod
    def get_num_samples(self, split: Split) -> int:
        """The number of samples in the requested split."""
        raise NotImplementedError()

    def has_validation_set(self) -> bool:
        """Does the dataset have a validation set.

        Returns ``True`` unless a subclass overrides this method.
        """
        return True

    @abstractmethod
    def has_test_set(self) -> bool:
        """Does the dataset have a test set."""
        raise NotImplementedError()
# NOTE(review): the extension classes below do not derive from ``ABC``, so
# ``@abstractmethod`` is not enforced at instantiation time for them; a method
# that is left unimplemented fails with ``NotImplementedError`` when called.
# Confirm whether that is intended before changing the base classes.
class HasClassCounts:
    """Extension for datasets that provide class frequencies."""

    @abstractmethod
    def class_counts(self, split: Split) -> torch.Tensor:
        """Return the class frequencies for the requested split."""
        raise NotImplementedError()


class HasInputCounts:
    """Extension for datasets that provide input feature frequencies."""

    @abstractmethod
    def input_counts(self, split: Split) -> torch.Tensor:
        """Return the input feature frequencies for the requested split."""
        raise NotImplementedError()


class MovableToLocal:
    """Extension for large datasets that need to be moved to local storage on SLURM nodes."""

    @abstractmethod
    def is_on_local(self) -> bool:
        """Returns True if the dataset is already in local storage."""
        raise NotImplementedError()

    @abstractmethod
    def move_to_local(self):
        """Moves the dataset from the workspace to local storage."""
        raise NotImplementedError()


class Downloadable:
    """Extension for datasets that can be downloaded.

    For interaction with the ``prepare`` CLI command.
    """

    @abstractmethod
    def is_downloaded(self) -> bool:
        """Checks whether the dataset is already in the workspace."""
        raise NotImplementedError()

    @abstractmethod
    def download(self):
        """Downloads the dataset to the workspace."""
        raise NotImplementedError()


class InMemory:
    """Extension for small datasets that can be loaded directly into RAM."""

    @abstractmethod
    def get_in_memory_dataloader(
        self,
        b: int,
        split: Split,
        num_workers: int,
        to_device: Optional[Device] = None,
    ) -> DataLoader:
        """Returns a Dataloader with the dataset already loaded into RAM on the device.

        Args:
            b: Batch size of the returned dataloader.
            split: Identifier of the split to load (one of ``Split``).
            num_workers: Number of workers used for data loading.
            to_device: Optional target device for the loaded data; defaults
                to ``None``.
        """
        raise NotImplementedError()