lgatr.layers.mlp.nonlinearities.ScalarGatedNonlinearity
- class lgatr.layers.mlp.nonlinearities.ScalarGatedNonlinearity(nonlinearity='relu')
Bases: torch.nn.Module
Gated nonlinearity on multivectors.
Given multivector input x, computes f(x_0) * x, where x_0 is the scalar component of x and f is ReLU, sigmoid, or GELU. Auxiliary scalar inputs are processed with the same nonlinearity, without gating.
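A minimal sketch of the gating in plain PyTorch, not the library's own code. It assumes the scalar component x_0 is stored at index 0 of the last axis of a (…, 16) tensor; the helper name scalar_gate is hypothetical. The example shows that a single gate value rescales all 16 components at once.

```python
import torch
import torch.nn.functional as F


def scalar_gate(x: torch.Tensor, f=F.relu) -> torch.Tensor:
    """Hypothetical helper: compute f(x_0) * x for multivectors x of shape (..., 16).

    Assumes the scalar component x_0 sits at index 0 of the last axis.
    """
    gate = f(x[..., [0]])  # shape (..., 1), broadcasts over all 16 components
    return gate * x


# Two multivectors: one with a negative scalar part, one with x_0 = 2.
x = torch.zeros(2, 16)
x[0, 0], x[0, 5] = -1.0, 3.0
x[1, 0], x[1, 5] = 2.0, 3.0

out = scalar_gate(x)
print(out[0].abs().sum().item())  # 0.0: ReLU(-1) = 0 switches the whole multivector off
print(out[1, 5].item())           # 6.0: ReLU(2) = 2 rescales every component by 2
```

Because every component is multiplied by the same function of x_0, the gate only rescales the multivector and never mixes its components; since x_0 is the invariant scalar part, this keeps the operation equivariant.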
- Parameters:
nonlinearity ({"relu", "sigmoid", "gelu"}) – Nonlinearity type, used for the multivector gate and for the auxiliary scalars. Defaults to "relu".
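The single nonlinearity string presumably selects one activation that is then used both for the gate and for the auxiliary scalars. The following sketches such a lookup with standard PyTorch activations; the table NONLINEARITIES, the helper resolve_nonlinearity, and the ValueError on unknown names are illustrative assumptions, not the library's API.

```python
import torch
import torch.nn.functional as F

# Hypothetical lookup table: one name selects the activation for both paths.
NONLINEARITIES = {
    "relu": F.relu,
    "sigmoid": torch.sigmoid,
    "gelu": F.gelu,
}


def resolve_nonlinearity(name: str):
    """Return the activation for `name` (raising ValueError on unknown names is an assumption)."""
    try:
        return NONLINEARITIES[name]
    except KeyError:
        raise ValueError(
            f"Unknown nonlinearity {name!r}; expected one of {sorted(NONLINEARITIES)}"
        ) from None


f = resolve_nonlinearity("gelu")
print(f(torch.tensor([0.0])).item())  # 0.0, since GELU(0) = 0
```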
- forward(multivectors, scalars)
Computes f(x_0) * x for multivector x, where f is GELU, ReLU, or sigmoid. The multivector gate uses self.gated_nonlinearity and the auxiliary scalars use self.scalar_nonlinearity; both are selected by the nonlinearity constructor argument.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, 16)
scalars (None or torch.Tensor) – Input scalars with shape (…)
- Return type:
Tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, 16)
output_scalars (torch.Tensor) – Output scalars with shape (…)
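The forward pass described above can be sketched as a small torch.nn.Module. ScalarGatedSketch is a hypothetical stand-in, not lgatr's implementation: it assumes x_0 is at index 0 of the last axis and, because the signature allows scalars to be None, it passes None through unchanged (also an assumption). It returns an (outputs_mv, output_scalars) tuple matching the documented shapes.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn


class ScalarGatedSketch(nn.Module):
    """Hypothetical illustration of the documented behaviour, for illustration only."""

    def __init__(self, nonlinearity: str = "relu") -> None:
        super().__init__()
        fns = {"relu": F.relu, "sigmoid": torch.sigmoid, "gelu": F.gelu}
        if nonlinearity not in fns:
            raise ValueError(f"Unknown nonlinearity {nonlinearity!r}")
        # One choice drives both paths (assumption based on the single constructor argument).
        self.fn = fns[nonlinearity]

    def forward(
        self, multivectors: torch.Tensor, scalars: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Gate f(x_0), kept as shape (..., 1) so it broadcasts over all 16 components.
        gates = self.fn(multivectors[..., [0]])
        outputs_mv = gates * multivectors  # (..., 16)
        # Scalars get the same activation without gating; None passes through (assumption).
        output_scalars = None if scalars is None else self.fn(scalars)
        return outputs_mv, output_scalars


torch.manual_seed(0)
layer = ScalarGatedSketch("sigmoid")
mv = torch.randn(8, 4, 16)  # e.g. (batch, items, 16)
s = torch.randn(8, 4, 32)   # auxiliary scalars; trailing size is arbitrary here
out_mv, out_s = layer(mv, s)
print(out_mv.shape, out_s.shape)  # torch.Size([8, 4, 16]) torch.Size([8, 4, 32])
```

Only the trailing axis of size 16 is fixed for the multivectors; the leading axes and the scalar shape are arbitrary, and both outputs keep the shapes of their inputs.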