lgatr.layers.layer_norm.EquiLayerNorm

class lgatr.layers.layer_norm.EquiLayerNorm(mv_channel_dim=-2, epsilon=0.01)

Bases: Module

Equivariant layer normalization.

Rescales the multivector inputs such that mean_channels |inputs|^2 = 1, where |·| is the GA norm and the mean is taken over the channel dimension (mv_channel_dim).

In addition, the layer performs a regular LayerNorm operation on auxiliary scalar inputs.
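A minimal PyTorch sketch of the two normalizations described above. It assumes 16 components in the blade order [1, e0, e1, e2, e3, e01, e02, e03, e12, e13, e23, e012, e013, e023, e123, e0123], a Minkowski metric with signature (+, -, -, -), a GA norm taken as the absolute value of the invariant inner product <x, x>, and an epsilon applied as a lower clamp on the channel-averaged squared norm. None of these details is confirmed by this page, and the names `BLADE_SIGNS`, `squared_ga_norm` and `equi_layer_norm_sketch` are hypothetical, not part of the lgatr API.

```python
import torch
import torch.nn.functional as F

# Assumed blade order: 1, e0, e1, e2, e3, e01, e02, e03, e12, e13, e23,
# e012, e013, e023, e123, e0123 with metric diag(+1, -1, -1, -1).
# Each sign is the product of the metric entries of the blade's basis vectors.
BLADE_SIGNS = torch.tensor(
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0,
     1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0]
)


def squared_ga_norm(mv: torch.Tensor) -> torch.Tensor:
    """|<x, x>| for each multivector; output shape (..., 1)."""
    inner = torch.sum(BLADE_SIGNS.to(mv) * mv * mv, dim=-1, keepdim=True)
    return inner.abs()


def equi_layer_norm_sketch(mv, scalars, channel_dim=-2, epsilon=0.01):
    # Mean over the channel dimension of the squared GA norm, e.g. (batch, 1, 1).
    mean_sq = squared_ga_norm(mv).mean(dim=channel_dim, keepdim=True)
    # Assumed role of epsilon: lower clamp, so the factor is at most 1 / sqrt(epsilon).
    mean_sq = mean_sq.clamp(min=epsilon)
    # One common factor per item: all channels and components are rescaled alike.
    out_mv = mv / torch.sqrt(mean_sq)
    # Standard LayerNorm over the last scalar dimension, no learnable parameters.
    out_s = F.layer_norm(scalars, normalized_shape=scalars.shape[-1:])
    return out_mv, out_s


torch.manual_seed(0)
mv = torch.randn(8, 4, 16)  # (batch, channels, components)
s = torch.randn(8, 32)      # (batch, scalar channels)
out_mv, out_s = equi_layer_norm_sketch(mv, s)
print(squared_ga_norm(out_mv).mean(dim=-2).flatten())  # approximately all ones
```

Because the rescaling factor is built only from invariant inner products and is shared by every channel and component of an item, multiplying by it commutes with Lorentz transformations, which is why the operation stays equivariant. A side effect is that the output does not depend on the overall scale of the input unless the clamp is active.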

Parameters:
  • mv_channel_dim (int) – Index of the channel dimension of the multivector inputs. Defaults to -2, the second-to-last dimension, since the last dimension holds the multivector components.

  • epsilon (float) – Small numerical constant that prevents instabilities when the norm is close to zero. The default is deliberately fairly large to counter issues that arise because some multivector components do not contribute to the norm.
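To see what a fairly large epsilon buys, the sketch below assumes that epsilon acts as a lower clamp on the channel-averaged squared norm before the inverse square root is taken. This mechanism is an assumption, not something this page states. Under it, the rescaling factor can never exceed 1/sqrt(epsilon), which is 10 for the default of 0.01, so near-zero inputs are not blown up.

```python
import torch

epsilon = 0.01  # documented default
# Hypothetical channel-averaged squared norms, from large down to exactly zero.
mean_sq = torch.tensor([4.0, 1.0, 1e-2, 1e-4, 0.0])
# Assumed mechanism: clamp from below, then take the inverse square root.
scale = 1.0 / torch.sqrt(torch.clamp(mean_sq, min=epsilon))
print(scale)  # -> 0.5, 1.0, 10.0, 10.0, 10.0
```

A smaller epsilon would let the factor grow much larger for inputs whose norm nearly vanishes, which is the kind of instability the default is meant to avoid.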

forward(multivectors, scalars)

Forward pass. Computes the equivariant LayerNorm for the multivectors and a regular LayerNorm for the scalars.

Parameters:
  • multivectors (torch.Tensor) – Multivector inputs with shape (…, 16).

  • scalars (torch.Tensor) – Scalar inputs with shape (…).

Return type:

Tuple[Tensor, Tensor]

Returns:

  • outputs_mv (torch.Tensor) – Normalized multivectors with shape (…, 16).

  • output_scalars (torch.Tensor) – Normalized scalars with shape (…).
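The call pattern documented above can be mirrored by a small torch.nn.Module. The class `EquiLayerNormSketch` below is hypothetical and only illustrates how the constructor arguments and the returned tuple fit together. It assumes the blade order [1, e0, e1, e2, e3, e01, e02, e03, e12, e13, e23, e012, e013, e023, e123, e0123], the metric signature (+, -, -, -), and epsilon as a lower clamp, none of which is confirmed by this page. With the real layer, usage follows the documented signature: construct it once, pass (multivectors, scalars), and unpack two tensors.

```python
from typing import Tuple

import torch
from torch import nn

# Assumed blade order 1, e0, e1, e2, e3, e01, e02, e03, e12, e13, e23,
# e012, e013, e023, e123, e0123 with metric diag(+1, -1, -1, -1).
_SIGNS = [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0,
          1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0]


class EquiLayerNormSketch(nn.Module):
    """Hypothetical stand-in mirroring EquiLayerNorm(mv_channel_dim, epsilon)."""

    def __init__(self, mv_channel_dim: int = -2, epsilon: float = 0.01):
        super().__init__()
        self.mv_channel_dim = mv_channel_dim
        self.epsilon = epsilon
        # Registered as a buffer so the signs move with the module on .to(device).
        self.register_buffer("signs", torch.tensor(_SIGNS))

    def forward(
        self, multivectors: torch.Tensor, scalars: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # |<x, x>| per multivector, averaged over the channel dimension and clamped.
        sq = torch.sum(self.signs * multivectors**2, dim=-1, keepdim=True).abs()
        sq = sq.mean(dim=self.mv_channel_dim, keepdim=True).clamp(min=self.epsilon)
        outputs_mv = multivectors / torch.sqrt(sq)
        # Standard LayerNorm over the last scalar dimension.
        output_scalars = nn.functional.layer_norm(scalars, scalars.shape[-1:])
        return outputs_mv, output_scalars


layer = EquiLayerNormSketch()
mv = torch.randn(2, 10, 4, 16)  # (batch, items, channels, 16)
s = torch.randn(2, 10, 8)       # (batch, items, scalar channels)
outputs_mv, output_scalars = layer(mv, s)
print(outputs_mv.shape, output_scalars.shape)
# torch.Size([2, 10, 4, 16]) torch.Size([2, 10, 8])
```

Both outputs keep the shapes of their inputs; only the multivector channel dimension is averaged over when computing the normalization factor.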