lgatr.layers.mlp.nonlinearities.ScalarGatedNonlinearity
- class lgatr.layers.mlp.nonlinearities.ScalarGatedNonlinearity(nonlinearity='relu')
Bases: torch.nn.Module
Gated nonlinearity on multivectors.
Given multivector input x, computes f(x_0) * x, where x_0 is the scalar component of x and f can be ReLU, sigmoid, or GELU.
Auxiliary scalar inputs are simply processed with ReLU, sigmoid, GELU, or SiLU, without gating.
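The gating rule can be illustrated on a single multivector stored as a plain list of 16 floats. This is a minimal stdlib-only sketch, not lgatr's implementation: the names `scalar_gated` and `NONLINEARITIES` are hypothetical, and the GELU here is the exact erf-based form (the library may use a different variant). The point it shows is that one gate value, computed from x_0 alone, multiplies every component of x.

```python
import math
from typing import Callable, Dict, List


def relu(x: float) -> float:
    return max(x, 0.0)


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in exp)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def silu(x: float) -> float:
    return x * sigmoid(x)


# Hypothetical lookup table from the `nonlinearity` string to f
NONLINEARITIES: Dict[str, Callable[[float], float]] = {
    "relu": relu,
    "sigmoid": sigmoid,
    "gelu": gelu,
    "silu": silu,
}


def scalar_gated(mv: List[float], nonlinearity: str = "relu") -> List[float]:
    """Return f(x_0) * x for one 16-component multivector x."""
    if len(mv) != 16:
        raise ValueError("expected 16 multivector components")
    # The gate depends only on the scalar component x_0 ...
    gate = NONLINEARITIES[nonlinearity](mv[0])
    # ... and rescales all 16 components by the same factor
    return [gate * c for c in mv]
```

For example, `scalar_gated([3.0] + [1.0] * 15)` scales every component by relu(3.0) = 3.0, while a negative x_0 with ReLU zeroes the whole multivector. Because the output is just a rescaled copy of x and the gate is built from the scalar (grade-0) part, this style of nonlinearity does not mix components in a way that would break equivariance.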
- Parameters:
nonlinearity ({"relu", "sigmoid", "gelu", "silu"}) – Nonlinearity type. Defaults to "relu".
- forward(multivectors, scalars)
Computes f(x_0) * x for multivector x, where f is GELU, ReLU, sigmoid, or SiLU.
The function f is chosen depending on self.gated_nonlinearity and self.scalar_nonlinearity.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, 16)
scalars (None or torch.Tensor) – Input scalars with shape (…)
- Return type:
tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, 16)
output_scalars (torch.Tensor) – Output scalars with shape (…)