Source code for lgatr.primitives.nonlinearities

"""Gated nonlinearities on multivectors."""
import torch


def gated_relu(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor:
    """Pin-equivariant gated ReLU nonlinearity.

    Given multivector input x and scalar input gates (with matching batch
    dimensions), computes ``ReLU(gates) * x``.

    Parameters
    ----------
    x : torch.Tensor
        Multivector input with shape (..., 16).
    gates : torch.Tensor
        Pin-invariant gates with shape (..., 1).

    Returns
    -------
    outputs : torch.Tensor
        Computes ReLU(gates) * x, with broadcasting along the last dimension.
        Result has shape (..., 16).
    """
    weights = torch.nn.functional.relu(gates)
    outputs = weights * x
    return outputs
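
# Illustrative usage sketch, not part of the module: the gates must be
# Pin-invariant scalars, and taking the grade-0 (scalar) component of ``x``
# is one such choice, assumed here for demonstration only.
#
# >>> x = torch.randn(8, 16)       # batch of 8 multivectors
# >>> gates = x[..., [0]]          # scalar component, shape (8, 1)
# >>> gated_relu(x, gates).shape   # gates broadcast over the 16 components
# torch.Size([8, 16])
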
def gated_sigmoid(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor:
    """Pin-equivariant gated sigmoid nonlinearity.

    Given multivector input x and scalar input gates (with matching batch
    dimensions), computes ``sigmoid(gates) * x``.

    Parameters
    ----------
    x : torch.Tensor
        Multivector input with shape (..., 16).
    gates : torch.Tensor
        Pin-invariant gates with shape (..., 1).

    Returns
    -------
    outputs : torch.Tensor
        Computes sigmoid(gates) * x, with broadcasting along the last dimension.
        Result has shape (..., 16).
    """
    weights = torch.sigmoid(gates)
    outputs = weights * x
    return outputs
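
# Illustrative usage sketch, not part of the module: sigmoid maps the gates
# into (0, 1), so each multivector is smoothly scaled rather than hard-gated.
#
# >>> x = torch.randn(8, 16)
# >>> gates = torch.zeros(8, 1)    # sigmoid(0) = 0.5
# >>> torch.allclose(gated_sigmoid(x, gates), 0.5 * x)
# True
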
def gated_gelu(x: torch.Tensor, gates: torch.Tensor) -> torch.Tensor:
    """Pin-equivariant gated GeLU nonlinearity without division.

    Given multivector input x and scalar input gates (with matching batch
    dimensions), computes ``GeLU(gates) * x``.

    References
    ----------
    Dan Hendrycks, Kevin Gimpel, "Gaussian Error Linear Units (GELUs)",
    arXiv:1606.08415

    Parameters
    ----------
    x : torch.Tensor
        Multivector input with shape (..., 16).
    gates : torch.Tensor
        Pin-invariant gates with shape (..., 1).

    Returns
    -------
    outputs : torch.Tensor
        Computes GeLU(gates) * x, with broadcasting along the last dimension.
        Result has shape (..., 16).
    """
    weights = torch.nn.functional.gelu(gates, approximate="tanh")
    outputs = weights * x
    return outputs
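
# Illustrative usage sketch, not part of the module: the tanh-approximate GeLU
# is used instead of the exact erf form, and the gates broadcast over arbitrary
# leading batch dimensions.
#
# >>> x = torch.randn(2, 8, 16)    # (batch, items, multivector components)
# >>> gates = torch.randn(2, 8, 1)
# >>> gated_gelu(x, gates).shape
# torch.Size([2, 8, 16])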