from abc import ABC, abstractmethod
import torch
from attrs import frozen
from optexp.component import Component
def assert_batch_sizes_match(b1: int, b2: int) -> None:
    """Check that two batch sizes are equal.

    Args:
        b1: First batch size.
        b2: Second batch size.

    Raises:
        AssertionError: If ``b1`` and ``b2`` differ.
    """
    # Raise explicitly rather than via an ``assert`` statement: ``assert`` is
    # removed when Python runs with ``-O``, which would silently disable this
    # check. The exception type and message are unchanged.
    if b1 != b2:
        raise AssertionError(f"Batch sizes do not match: {b1} != {b2}")
@frozen
class Model(Component, ABC):
    """Abstract interface implemented by every model component.

    Declared as an attrs ``frozen`` class deriving from :class:`Component`.
    It cannot be instantiated directly: concrete subclasses must override
    :meth:`load_model`.
    """

    @abstractmethod
    def load_model(
        self,
        input_shape: torch.Size,
        output_shape: torch.Size,
    ) -> torch.nn.Module:
        """Build the :class:`torch.nn.Module` for this model.

        Args:
            input_shape: Shape of the model input.
                NOTE(review): whether this includes the batch dimension is
                not visible here -- confirm against callers.
            output_shape: Shape of the model output (same caveat).

        Returns:
            The module produced by the concrete subclass.

        Raises:
            NotImplementedError: If this base implementation is invoked.
        """
        raise NotImplementedError