lgatr.layers.mlp.nonlinearities.ScalarGatedNonlinearity

class lgatr.layers.mlp.nonlinearities.ScalarGatedNonlinearity(nonlinearity='relu')

Bases: Module

Gated nonlinearity on multivectors.

Given a multivector input x, computes f(x_0) * x, where x_0 is the scalar (grade-0) component of x and f is ReLU, sigmoid, or GELU.

Auxiliary scalar inputs are passed elementwise through the same nonlinearity (ReLU, sigmoid, or GELU), without gating.
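As a minimal sketch of the formula above, assuming the scalar component sits at index 0 of the last axis (consistent with the (…, 16) layout), the gating can be written in plain PyTorch. The helper `gated` is hypothetical and not part of lgatr.

```python
import torch
import torch.nn.functional as F


def gated(x: torch.Tensor, f) -> torch.Tensor:
    """Hypothetical helper: scale each multivector in x (shape (..., 16)) by f(x_0)."""
    # Slice with 0:1 to keep a size-1 trailing axis, so the gate broadcasts over all 16 components.
    gate = x[..., 0:1]
    return f(gate) * x


# Two multivectors: the first has a positive scalar component, the second a negative one.
x = torch.zeros(2, 16)
x[0, 0], x[0, 5] = 2.0, -3.0
x[1, 0], x[1, 5] = -1.0, 4.0

out = gated(x, F.relu)  # row 0 is scaled by relu(2.0) = 2.0, row 1 by relu(-1.0) = 0.0

# Auxiliary scalars get plain elementwise ReLU, with no gate.
s = torch.tensor([-1.0, 0.5])
s_out = F.relu(s)
```

Because each multivector is multiplied by a single number, the output is always a rescaled copy of the input; only the scalar component decides how strongly it passes through.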

Parameters:

nonlinearity ({"relu", "sigmoid", "gelu"}) – Type of nonlinearity. Defaults to "relu".
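One way the nonlinearity string could be resolved into a gated function and a scalar function is a lookup table. This is a sketch under assumptions: the name `make_nonlinearities` is hypothetical, and raising `ValueError` for an unknown name is an assumed behavior, not something this page documents.

```python
import torch
import torch.nn.functional as F

# Hypothetical lookup table covering the three documented options.
_SCALAR_FNS = {"relu": F.relu, "sigmoid": torch.sigmoid, "gelu": F.gelu}


def make_nonlinearities(name: str = "relu"):
    """Return (gated_fn, scalar_fn) for a nonlinearity name (illustrative only)."""
    try:
        f = _SCALAR_FNS[name]
    except KeyError:
        # Assumed error handling: reject names outside the documented set.
        raise ValueError(
            f"Unknown nonlinearity {name!r}; expected one of {sorted(_SCALAR_FNS)}"
        ) from None

    def gated_fn(x: torch.Tensor) -> torch.Tensor:
        # Gate every component of each multivector with f of its scalar component.
        return f(x[..., 0:1]) * x

    return gated_fn, f


gated_fn, scalar_fn = make_nonlinearities("sigmoid")
```

Deriving both functions from one name keeps the multivector gate and the scalar activation consistent with each other.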

forward(multivectors, scalars)

Computes f(x_0) * x for multivector x, where f is GELU, ReLU, or sigmoid.

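The choice of f is fixed at construction by the nonlinearity argument: self.gated_nonlinearity applies it as a gate to the multivectors, and self.scalar_nonlinearity applies it elementwise to the scalars.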

Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, 16)

  • scalars (None or torch.Tensor) – Input scalars with shape (…)

Return type:

Tuple[Tensor, Tensor]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, 16)

  • output_scalars (torch.Tensor) – Output scalars with shape (…)
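To illustrate the forward contract and the expected shapes, here is a self-contained stand-in module. `GatedNonlinearitySketch` is hypothetical, not the lgatr class; passing `None` scalars straight through as `None` is an assumption based on the "None or torch.Tensor" parameter type.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GatedNonlinearitySketch(nn.Module):
    """Hypothetical stand-in mirroring the documented forward() inputs and outputs."""

    _FNS = {"relu": F.relu, "sigmoid": torch.sigmoid, "gelu": F.gelu}

    def __init__(self, nonlinearity: str = "relu"):
        super().__init__()
        self.f = self._FNS[nonlinearity]

    def forward(self, multivectors, scalars):
        # Gate: f of the scalar component, kept as a size-1 axis to broadcast over 16 components.
        outputs_mv = self.f(multivectors[..., 0:1]) * multivectors
        # Auxiliary scalars: elementwise f, no gating (None passes through; an assumption).
        output_scalars = self.f(scalars) if scalars is not None else None
        return outputs_mv, output_scalars


layer = GatedNonlinearitySketch("gelu")
mv = torch.randn(8, 4, 16)  # (batch, multivector channels, 16)
s = torch.randn(8, 3)       # (batch, scalar channels)
out_mv, out_s = layer(mv, s)
print(out_mv.shape, out_s.shape)  # torch.Size([8, 4, 16]) torch.Size([8, 3])
```

Both outputs keep the shapes of their inputs, matching the (…, 16) and (…) shapes listed above.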