Source code for optexp.hardwareconfig.hardwareconfig

from abc import ABC, abstractmethod
from typing import Literal

from attrs import frozen

from optexp.component import Component
from optexp.problem import Problem


@frozen
class BatchSizeInfo:
    """Immutable record of batch-size and data-loading worker settings.

    Returned by ``HardwareConfig.get_batch_size_info``. All fields are
    plain integers; their exact semantics are defined by the concrete
    hardware configuration that produces them.
    """

    # NOTE(review): the per-field meanings below are inferred from the
    # names only — confirm against the code that consumes them.
    mbatchsize_tr: int  # presumably the (micro-)batch size used for training
    mbatchsize_va: int  # presumably the (micro-)batch size used for validation
    accumulation_steps: int  # presumably the number of gradient-accumulation steps
    workers_tr: int  # presumably the number of data-loading workers for training
    workers_va: int  # presumably the number of data-loading workers for validation


@frozen
class HardwareConfig(Component, ABC):
    """Abstract base class for hardware configurations.

    Concrete subclasses must implement every abstract method below; the
    base implementations only raise ``NotImplementedError``.
    """

    @abstractmethod
    def get_batch_size_info(self, problem: Problem) -> BatchSizeInfo:
        """Return the batch-size information to use for ``problem``.

        Args:
            problem: The problem the batch sizes are requested for.

        Returns:
            A ``BatchSizeInfo`` describing batch sizes, accumulation steps
            and worker counts.

        Raises:
            NotImplementedError: Always, in this abstract base method.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_devices(self) -> int:
        """Return the number of devices this configuration uses.

        Raises:
            NotImplementedError: Always, in this abstract base method.
        """
        raise NotImplementedError

    @abstractmethod
    def get_accelerator(self) -> Literal["cpu", "cuda"]:
        """Return the accelerator kind, either ``"cpu"`` or ``"cuda"``.

        Raises:
            NotImplementedError: Always, in this abstract base method.
        """
        raise NotImplementedError