"""Multivector normalization."""
import torch
from .invariants import abs_squared_norm
def equi_layer_norm(
    x: torch.Tensor, channel_dim: int = -2, gain: float = 1.0, epsilon: float = 0.01
) -> torch.Tensor:
    """Equivariant LayerNorm for multivectors.

    Rescales the input such that ``mean_channels |inputs|^2 = 1``, where the norm is
    the GA norm and the mean goes over the channel dimensions.

    Using a factor ``gain > 1`` makes up for the fact that the GP norm overestimates
    the actual standard deviation of the input data.

    Parameters
    ----------
    x : torch.Tensor
        Multivectors with shape (..., 16).
    channel_dim : int
        Channel dimension index. Defaults to the second-last entry (the last axis
        holds the 16 multivector components).
    gain : float
        Target output scale.
    epsilon : float
        Small numerical factor to avoid instabilities. By default, we use a
        reasonably large number to balance issues that arise from some multivector
        components not contributing to the norm.

    Returns
    -------
    outputs : torch.Tensor
        Normalized multivectors with shape (..., 16).
    """
    # Compute mean_channels |inputs|^2.
    abs_squared_norms = abs_squared_norm(x)
    abs_squared_norms = torch.mean(abs_squared_norms, dim=channel_dim, keepdim=True)

    # Insure against low-norm tensors (which can arise even when `x.var(dim=-1)` is
    # high, because some entries don't contribute to the inner product / GP norm).
    abs_squared_norms = torch.clamp(abs_squared_norms, min=epsilon)

    # Rescale inputs to the target scale.
    return gain * x / torch.sqrt(abs_squared_norms)