Adding a model

To add your own model, or to integrate a model from another library, create a new component that extends the `Model` class.

The `Model` interface declares the following abstract method, which every subclass must implement:

@frozen
class Model(Component, ABC):
    """Abstract interface shared by every model component.

    A concrete model supplies a factory that, given the shapes of the input
    and target data, returns the :class:`torch.nn.Module` to use.
    """

    @abstractmethod
    def load_model(
        self,
        input_shape: torch.Size,
        output_shape: torch.Size,
    ) -> torch.nn.Module:
        """Build and return the network for this model.

        Args:
            input_shape: Shape of the input data the network will receive.
            output_shape: Shape of the target data the network must produce.

        Returns:
            A :class:`torch.nn.Module` implementing this model.

        Raises:
            NotImplementedError: If this base implementation is invoked
                directly (for example through ``super()``).
        """
        raise NotImplementedError

As an example, let's walk through the current implementation of `LeNet5`:

@frozen
class LeNet5(Model):
    """A basic convolutional neural network for image classification from [LeCun1998]_.

    The model expects images of shape [batch, channels, 32, 32].
    If images are 28x28, the model will pad the images to 32x32.

    .. [LeCun1998] Gradient Based Learning Applied to Document Recognition.
       Yann LeCun, Leon Bottou, Yoshua Bengio, and Patrick Haffner.
       Proceedings of the IEEE, 86(11):2278-2324, 1998.
       `DOI: 10.1109/5.726791 <https://doi.org/10.1109/5.726791>`_
    """

    def load_model(
        self, input_shape: torch.Size, output_shape: torch.Size
    ) -> torch.nn.Module:
        """Build a LeNet-5 network sized for the given data shapes.

        Args:
            input_shape: Shape of the input batch, unpacked as
                ``(batch, channels, height, width)``.
            output_shape: Shape of the target batch, unpacked as
                ``(batch, num_classes)``.

        Returns:
            A newly constructed :class:`torch.nn.Module` whose first
            convolution takes ``channels`` input channels and whose final
            linear layer emits ``num_classes`` unnormalized scores.
        """
        validate_image_data(input_shape, output_shape)

        b_in, channels, _, _ = input_shape
        b_out, num_classes = output_shape
        assert_batch_sizes_match(b_in, b_out)

        class LeNet5Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(channels, 6, 5)
                self.conv2 = torch.nn.Conv2d(6, 16, 5)
                # For a 32x32 input: conv(5) -> 28, pool(2) -> 14,
                # conv(5) -> 10, pool(2) -> 5, hence 16 maps of 5x5.
                self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
                self.fc2 = torch.nn.Linear(120, 84)
                self.fc3 = torch.nn.Linear(84, num_classes)

            def forward(self, x):
                spatial = tuple(x.shape[-2:])
                if spatial == (28, 28):
                    # Zero-pad 2 pixels on each side of H and W (28x28 -> 32x32).
                    # F.pad avoids building a torchvision transform per call.
                    x = F.pad(x, (2, 2, 2, 2), mode="constant", value=0.0)
                elif spatial != (32, 32):
                    raise ValueError(
                        f"Input shape must be 28x28 or 32x32. Got {x.shape}"
                    )

                x = F.max_pool2d(F.relu(self.conv1(x)), 2)
                x = F.max_pool2d(F.relu(self.conv2(x)), 2)
                x = torch.flatten(x, 1)
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                output = self.fc3(x)

                return output

        return LeNet5Module()