Source code for optexp.optim.sgd

import torch
from attrs import frozen

from optexp.optim.optimizer import Optimizer, Regularizable, WeightDecayStrategy
from optexp.optim.weight_decay_strategies import DecayEverything
from optexp.plotting.colors import Colors


@frozen
class SGD(Optimizer, Regularizable):
    """Stochastic Gradient Descent configuration.

    The fields mirror the keyword arguments of ``torch.optim.SGD``. The torch
    optimizer itself is only built when :meth:`load` is called with a model.

    Args:
        lr (float): learning rate.
        momentum (float, optional): momentum factor. Defaults to 0.
        dampening (float, optional): dampening applied to momentum. Defaults to 0.
        weight_decay (float, optional): weight decay (L2 penalty) coefficient.
            Defaults to 0.
        nesterov (bool, optional): whether to use Nesterov momentum.
            Defaults to False.
        decay_strategy (WeightDecayStrategy, optional): strategy that builds the
            parameter groups and the regularizer term for weight decay.
            Defaults to DecayEverything().
    """

    lr: float
    momentum: float = 0
    dampening: float = 0
    weight_decay: float = 0
    nesterov: bool = False
    decay_strategy: WeightDecayStrategy = DecayEverything()

    def load(self, model: torch.nn.Module) -> torch.optim.Optimizer:
        """Create a ``torch.optim.SGD`` instance for ``model``.

        Parameter groups are produced by ``decay_strategy.make_param_groups``.
        The hyperparameters below are passed as optimizer-wide defaults. Any
        option a group sets explicitly takes precedence over them (standard
        ``torch.optim`` per-parameter-group behaviour).
        """
        groups = self.decay_strategy.make_param_groups(model, self.weight_decay)
        hyperparams = dict(
            lr=self.lr,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
            nesterov=self.nesterov,
        )
        return torch.optim.SGD(groups, **hyperparams)

    def regularizer_loss(self, model: torch.nn.Module) -> torch.Tensor:
        """Return the regularization term ``decay_strategy`` computes for ``model``
        at this optimizer's ``weight_decay``."""
        return self.decay_strategy.regularizer_loss(model, self.weight_decay)

    def plot_label(self) -> str:
        """Build a legend label of the form ``GD (<hyperparameters>)``.

        Only the learning rate, momentum and weight decay are shown, and each
        only when it is neither ``None`` nor zero. Values are formatted with
        ``.3g`` inside LaTeX math delimiters.
        """
        candidates = (
            (r"\alpha", self.lr),
            (r"\beta", self.momentum),
            (r"\lambda", self.weight_decay),
        )
        shown = [
            f"${symbol}={value:.3g}$"
            for symbol, value in candidates
            if value is not None and value != 0
        ]
        return f"GD ({' '.join(shown)})"

    def plot_style(self):
        """Return the plot style dict (presumably consumed by matplotlib).

        The colour is always the high-contrast blue. The line is solid when
        momentum is positive and dashed otherwise.
        """
        if self.momentum > 0:
            linestyle = "-"
        else:
            linestyle = "--"
        return {"color": Colors.HighContrast.blue, "linestyle": linestyle}