Adding a dataset
To add your own dataset or integrate datasets from other libraries,
you can create a new component that extends the Dataset class.
This interface specifies the methods a dataset must provide; all of them
except has_validation_set are abstract and must be implemented:
@frozen
class Dataset(ABC, Component):
    """Abstract base class for datasets.

    A concrete dataset supplies dataloaders for its splits together with
    basic metadata: the shape of input batches, the shape of targets and
    the number of samples per split.
    """

    @abstractmethod
    def get_dataloader(self, b: int, split: Split, num_workers: int) -> DataLoader:
        """Build a dataloader with batch size ``b`` for the requested ``split``
        (training or validation)."""
        raise NotImplementedError

    @abstractmethod
    def data_input_shape(self, batch_size: int) -> torch.Size:
        """Return the expected shape of an input batch, batch dimension included."""
        raise NotImplementedError

    @abstractmethod
    def model_output_shape(self, batch_size: int) -> torch.Size:
        """Return the expected output shape for ``batch_size`` samples,
        i.e. the shape of the targets."""
        raise NotImplementedError

    @abstractmethod
    def get_num_samples(self, split: Split) -> int:
        """Return how many samples the training or validation split contains."""
        raise NotImplementedError

    def has_validation_set(self) -> bool:
        """Report whether a validation split exists; ``True`` unless overridden."""
        return True

    @abstractmethod
    def has_test_set(self) -> bool:
        """Report whether the dataset provides a test split."""
        raise NotImplementedError
As an example, we’ll follow the current implementation for MNIST.
Basic information about the dataset
The following three methods return basic information about the dataset: data_input_shape, model_output_shape, and get_num_samples.
# NOTE(review): this snippet appears to have been rendered under the tabs
# "input_shape", "output_shape" and "get_num_samples"; only the
# get_num_samples code survives here.
@abstractmethod
def get_num_samples(self, split: Split) -> int:
    """Return how many samples the training or validation split contains."""
    raise NotImplementedError
Making dataloaders
Finally, get_dataloader returns the actual torch.utils.data.DataLoader
that will be used to iterate over the requested split of the dataset.
@abstractmethod
def get_dataloader(self, b: int, split: Split, num_workers: int) -> DataLoader:
    """Build a dataloader with batch size ``b`` for the requested ``split``
    (training or validation)."""
    raise NotImplementedError


def make_dataloader(dataset, b, num_workers):
    """Wrap ``dataset`` in a torch DataLoader yielding batches of size ``b``.

    Samples are shuffled, the final incomplete batch is kept
    (``drop_last=False``) and memory pinning is always requested.
    """
    loader_kwargs = {
        "batch_size": b,
        "shuffle": True,
        "drop_last": False,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
Adding further integrations
Additional features can be implemented by inheriting from the following classes.
For example, to define a dataset that can be downloaded from the internet
when using the prepare CLI command, extend both Dataset
and Downloadable
and implement their respective abstract methods:
class MyDataset(Dataset, Downloadable):
    """Example dataset that can also be fetched by the ``prepare`` CLI command."""

    # Implement the abstract methods of both Dataset and Downloadable here.
    ...
The Downloadable extension is generally useful;
the others are entirely optional.
Downloadable datasets
class Downloadable(ABC):
    """Extension for datasets that can be downloaded.

    For interaction with the ``prepare`` CLI command.

    Inheriting from ``ABC`` (i.e. using ``ABCMeta``) is what makes the
    ``@abstractmethod`` markers below effective: a subclass that does not
    implement both methods can no longer be instantiated. A plain base class
    would silently ignore the markers, even when mixed into ``Dataset``.
    """

    @abstractmethod
    def is_downloaded(self) -> bool:
        """Check whether the dataset is already in the workspace."""
        raise NotImplementedError()

    @abstractmethod
    def download(self):
        """Download the dataset to the workspace."""
        raise NotImplementedError()
Datasets with class counts
class HasClassCounts(ABC):
    """Extension for datasets that provide class frequencies.

    Inheriting from ``ABC`` makes ``class_counts`` an enforced abstract
    method: subclasses that do not implement it cannot be instantiated.
    """

    @abstractmethod
    def class_counts(self, split: Split) -> torch.Tensor:
        """Return the per-class sample counts for ``split``.

        NOTE(review): presumably a 1-D tensor indexed by class label —
        confirm against the concrete implementations.
        """
        raise NotImplementedError()
Datasets that can be loaded into RAM
class InMemory(ABC):
    """Extension for small datasets that can be loaded directly into RAM.

    Inheriting from ``ABC`` makes ``get_in_memory_dataloader`` an enforced
    abstract method: subclasses that do not implement it cannot be
    instantiated.
    """

    @abstractmethod
    def get_in_memory_dataloader(
        self,
        b: int,
        split: Split,
        num_workers: int,
        to_device: Optional[Device] = None,
    ) -> torch.utils.data.DataLoader:
        """Return a DataLoader with the dataset already loaded into RAM on the device.

        Args:
            b: Batch size.
            split: Which split of the dataset to serve.
            num_workers: Worker count for the dataloader (assumed to mirror
                ``Dataset.get_dataloader``).
            to_device: Device to place the data on; the meaning of ``None``
                is left to implementations (TODO confirm).
        """
        raise NotImplementedError()
Datasets that need to be moved to local storage
class MovableToLocal(ABC):
    """Extension for large datasets that need to be moved to local storage on SLURM nodes.

    Inheriting from ``ABC`` makes both methods below enforced abstract
    methods: subclasses that do not implement them cannot be instantiated.
    """

    @abstractmethod
    def is_on_local(self) -> bool:
        """Return True if the dataset is already in local storage."""
        raise NotImplementedError()

    @abstractmethod
    def move_to_local(self):
        """Move the dataset from the workspace to local storage."""
        raise NotImplementedError()
MNIST does not implement this extension because the dataset is small.