# Source code for optexp.datasets.dummy

from typing import Optional

import torch
from torch.types import Device
from torch.utils.data import TensorDataset

from optexp.datasets.dataset import Dataset, Split
from optexp.datasets.utils import make_dataloader


def make_dataset(n):
    """Build a 1-D toy regression problem with ``n`` samples.

    The inputs are ``n`` evenly spaced points on ``[0, 1]`` and the targets
    are ``x + 0.1 * sin(10 * 2 * pi * x)``.

    Args:
        n: number of samples to generate.

    Returns:
        A tuple ``(inputs, targets)`` of tensors, each reshaped to ``(n, 1)``.
    """
    grid = torch.linspace(0, 1, n)
    # Ten full sine periods over the unit interval, scaled down to 0.1.
    oscillation = torch.sin(10 * 2 * torch.pi * grid)
    targets = grid + 0.1 * oscillation
    # Turn both 1-D vectors into single-feature column tensors.
    return grid.reshape(-1, 1), targets.reshape(-1, 1)


class DummyRegression(Dataset):
    """A dummy dataset for testing purposes.

    The targets are generated from the inputs by a sine-perturbed line::

        y = x + 0.1 * sin(10 * 2 * pi * x)

    The inputs x form a linear grid from 0 to 1
    (``linspace(0, 1, n_samples)``).

    Args:
        n_tr (int, optional): number of training samples. Defaults to 100.
        n_va (int, optional): number of validation samples. Defaults to 10.
    """

    n_tr: int = 100
    n_va: int = 10

    def get_num_samples(self, split: Split) -> int:
        """Return the number of samples in ``split``.

        Raises:
            ValueError: if ``split`` is neither ``"tr"`` nor ``"va"``.
        """
        if split == "tr":
            return self.n_tr
        if split == "va":
            return self.n_va
        raise ValueError(f"Invalid split: {split}")

    def data_input_shape(self, batch_size) -> torch.Size:
        """Return the shape of one input batch: ``(batch_size, 1)``."""
        return torch.Size([batch_size, 1])

    def model_output_shape(self, batch_size) -> torch.Size:
        """Return the shape of one output batch: ``(batch_size, 1)``."""
        return torch.Size([batch_size, 1])

    def has_test_set(self) -> bool:
        """Return ``False``: only the "tr" and "va" splits are handled here."""
        return False

    @staticmethod
    def _get_dataset(
        split: Split,
        to_device: Optional[Device] = None,
        n_samples: Optional[int] = None,
    ):
        """Build the ``TensorDataset`` for ``split``.

        Args:
            split: ``"tr"`` or ``"va"``.
            to_device: if given, inputs and targets are moved to this device.
            n_samples: number of samples to generate. When ``None`` the
                class-level default for ``split`` is used, which is what
                this method always did before the parameter existed.

        Raises:
            ValueError: if ``split`` is neither ``"tr"`` nor ``"va"``.
        """
        # Validate the split first so an invalid value is rejected even when
        # an explicit size is supplied.
        if split == "tr":
            default_size = DummyRegression.n_tr
        elif split == "va":
            default_size = DummyRegression.n_va
        else:
            raise ValueError(f"Invalid tr_va: {split}")

        size = default_size if n_samples is None else n_samples
        data, targets = make_dataset(size)
        if to_device is not None:
            data, targets = data.to(to_device), targets.to(to_device)
        return TensorDataset(data, targets)

    def get_dataloader(
        self,
        b: int,
        split: Split,
        num_workers: int,
        to_device: Optional[Device] = None,
    ) -> torch.utils.data.DataLoader:
        """Return a dataloader over ``split`` with batch size ``b``.

        The dataset is sized from this instance's ``n_tr`` / ``n_va`` so that
        it agrees with :meth:`get_num_samples`, rather than from the
        class-level defaults.

        Raises:
            ValueError: if ``split`` is neither ``"tr"`` nor ``"va"``.
        """
        # An unknown split maps to None here; _get_dataset then raises the
        # same ValueError it always raised for an invalid split.
        instance_sizes = {"tr": self.n_tr, "va": self.n_va}
        dataset = self._get_dataset(
            split,
            to_device=to_device,
            n_samples=instance_sizes.get(split),
        )
        return make_dataloader(dataset, b, num_workers)

    def get_tensor_dataloader(
        self,
        b: int,
        split: Split,
        num_workers: int,
        to_device: Optional[Device] = None,
    ) -> torch.utils.data.DataLoader:
        """Return the same dataloader as :meth:`get_dataloader`."""
        return self.get_dataloader(b, split, num_workers, to_device=to_device)