"""Equivariant transformer for vector and scalar data."""
import math
import torch
from torch import nn
from torch.nn.functional import dropout, dropout1d
from torch.utils.checkpoint import checkpoint
from ..primitives.attention import scaled_dot_product_attention
from ..utils.misc import minimum_autocast_precision
def inner_product(x, y):
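    """Minkowski inner product of x and y over the last dimension, with metric signature (+, -, -, -)."""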
t = x[..., 0] * y[..., 0]
s = (x[..., 1:] * y[..., 1:]).sum(dim=-1)
return t - s
def squared_norm(x):
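    """Minkowski squared norm <x, x>; note that this can be negative for spacelike inputs."""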
return inner_product(x, x)
def get_nonlinearity(label):
if label == "relu":
return nn.ReLU()
elif label == "sigmoid":
return nn.Sigmoid()
elif label == "tanh":
return nn.Tanh()
elif label == "gelu":
return nn.GELU()
elif label == "silu":
return nn.SiLU()
else:
raise ValueError(f"Unsupported nonlinearity type: {label}")
class Dropout(nn.Module):
"""Dropout module for scalar and vector features.
For vector features, the same dropout mask is applied to all four components of each vector.
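
    Examples
    --------
    A minimal sketch; the shapes and dropout probability are illustrative.

    >>> drop = Dropout(dropout_prob=0.1)
    >>> v = torch.randn(8, 16, 4)  # (..., v_channels, 4)
    >>> s = torch.randn(8, 32)  # (..., s_channels)
    >>> out_v, out_s = drop(v, s)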
"""
def __init__(self, dropout_prob: float):
super().__init__()
self._dropout_prob = dropout_prob
def forward(self, vectors, scalars):
"""
Parameters
----------
vectors : torch.Tensor
A tensor of shape (..., v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
A tensor of shape (..., s_channels) representing scalar features.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of the same shapes as the inputs, representing the vectors and scalars after dropout.
"""
        # dropout1d expects (N, C, L) or (C, L) inputs, so flatten each Lorentz vector into one channel of length 4;
        # zeroing a whole channel then drops all four components of a vector together
v = vectors.reshape(-1, 4)
out_v = dropout1d(v, p=self._dropout_prob, training=self.training)
out_v = out_v.reshape(vectors.shape)
out_s = dropout(scalars, p=self._dropout_prob, training=self.training)
return out_v, out_s
class RMSNorm(nn.Module):
"""Normalize jointly over vector and scalar features.
    For vectors, the absolute value of the Minkowski squared norm is used, since the Minkowski norm can be negative.
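
    Concretely, both feature types are divided by
    sqrt((sum_i |<v_i, v_i>| + sum_j s_j**2) / (n_v + n_s) + epsilon),
    where n_v and n_s are the numbers of vector and scalar channels.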
"""
def __init__(self, epsilon: float = 0.01):
super().__init__()
self.epsilon = epsilon
@minimum_autocast_precision(torch.float32)
def forward(self, vectors, scalars):
"""
Parameters
----------
vectors : torch.Tensor
A tensor of shape (..., v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
A tensor of shape (..., s_channels) representing scalar features.
Returns
-------
torch.Tensor, torch.Tensor
Tensors of the same shape as input representing the normalized vectors and scalars.
"""
v_squared_norm = squared_norm(vectors).abs()
s_squared_norm = scalars.square()
sum_squared_norms = v_squared_norm.sum(dim=-1) + s_squared_norm.sum(dim=-1)
mean_squared_norms = sum_squared_norms / (vectors.shape[-2] + scalars.shape[-1])
norm = torch.rsqrt(mean_squared_norms + self.epsilon).unsqueeze(-1)
vectors_out = vectors * norm.unsqueeze(-1)
scalars_out = scalars * norm
return vectors_out, scalars_out
class Linear(nn.Module):
"""Linear operations for vector and scalar features.
    Vector and scalar channels are transformed by separate linear maps; the vector map acts
    channel-wise on the Lorentz vectors and has no bias term, which preserves equivariance.
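
    Examples
    --------
    A minimal sketch; the channel counts are illustrative.

    >>> lin = Linear(in_v_channels=3, out_v_channels=5, in_s_channels=2, out_s_channels=7)
    >>> v = torch.randn(10, 3, 4)
    >>> s = torch.randn(10, 2)
    >>> out_v, out_s = lin(v, s)
    >>> out_v.shape, out_s.shape
    (torch.Size([10, 5, 4]), torch.Size([10, 7]))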
"""
def __init__(
self,
in_v_channels: int,
out_v_channels: int,
in_s_channels: int,
out_s_channels: int,
bias: bool = True,
initialization: str = "default",
):
"""
Parameters
----------
in_v_channels : int
Number of input vector channels.
out_v_channels : int
Number of output vector channels.
in_s_channels : int
Number of input scalar channels.
out_s_channels : int
Number of output scalar channels.
bias : bool, optional
Whether to include a bias term in the scalar linear layer, by default True.
initialization : str, optional
Initialization method for weights, by default "default".
The alternative "small" initializes weights to smaller values,
which might improve stability in attention projections.
"""
super().__init__()
self._in_v_channels = in_v_channels
self._out_v_channels = out_v_channels
self._in_s_channels = in_s_channels
self._out_s_channels = out_s_channels
self._bias = bias
self.weight_v = nn.Parameter(
torch.empty(
(
out_v_channels,
in_v_channels,
)
)
)
self.linear_s = nn.Linear(in_s_channels, out_s_channels, bias=bias)
self.reset_parameters(initialization)
def forward(self, vectors, scalars):
"""
Parameters
----------
vectors : torch.Tensor
            A tensor of shape (..., in_v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
            A tensor of shape (..., in_s_channels) representing scalar features.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of shape (..., out_v_channels, 4) and (..., out_s_channels) representing
            the transformed vectors and scalars.
"""
vectors_out = self.weight_v @ vectors
scalars_out = self.linear_s(scalars)
return vectors_out, scalars_out
def reset_parameters(self, initialization, additional_factor=1.0):
if initialization == "default":
v_factor = additional_factor
s_factor = additional_factor
elif initialization == "small":
v_factor = 0.1 * additional_factor
s_factor = 0.1 * additional_factor
else:
raise ValueError(f"Unknown initialization: {initialization}")
fan_in = max(self._in_v_channels, 1)
bound = v_factor / math.sqrt(fan_in)
nn.init.uniform_(self.weight_v, a=-bound, b=bound)
fan_in = max(self._in_s_channels, 1)
bound = s_factor / math.sqrt(fan_in)
nn.init.uniform_(self.linear_s.weight, a=-bound, b=bound)
class GatedLinearUnit(nn.Module):
"""Gated linear unit (GLU) for vector and scalar features.
Scalar gates are computed from scalar features,
while vector gates are computed from inner products of vector features.
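
    Concretely, the inputs are projected to (v_pre, v_g1, v_g2) and (s_pre, s_g); the outputs are
    sigma(<v_g1, v_g2>) * v_pre for the vectors and sigma(s_g) * s_pre for the scalars,
    where sigma is the chosen nonlinearity and <., .> is the Minkowski inner product.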
"""
def __init__(
self,
in_v_channels: int,
out_v_channels: int,
in_s_channels: int,
out_s_channels: int,
nonlinearity: str = "gelu",
):
super().__init__()
self.linear = Linear(
in_v_channels=in_v_channels,
out_v_channels=3 * out_v_channels,
in_s_channels=in_s_channels,
out_s_channels=2 * out_s_channels,
)
self.nonlinearity = get_nonlinearity(nonlinearity)
def forward(self, vectors, scalars):
"""
Parameters
----------
vectors : torch.Tensor
            A tensor of shape (..., in_v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
            A tensor of shape (..., in_s_channels) representing scalar features.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of shape (..., out_v_channels, 4) and (..., out_s_channels) representing
            the gated vectors and scalars.
"""
v_full, s_full = self.linear(vectors, scalars)
v_pre, v_gates_1, v_gates_2 = v_full.chunk(3, dim=-2)
s_pre, s_gates = s_full.chunk(2, dim=-1)
v_gates = inner_product(v_gates_1, v_gates_2).unsqueeze(-1)
vectors_out = self.nonlinearity(v_gates) * v_pre
scalars_out = self.nonlinearity(s_gates) * s_pre
return vectors_out, scalars_out
class SelfAttention(nn.Module):
"""Self-attention module for Lorentz vectors and scalar features."""
def __init__(
self,
v_channels: int,
s_channels: int,
num_heads: int,
attn_ratio: int = 1,
dropout_prob: float | None = None,
):
super().__init__()
self.hidden_v_channels = max(attn_ratio * v_channels // num_heads, 1)
self.hidden_s_channels = max(attn_ratio * s_channels // num_heads, 4)
self.num_heads = num_heads
metric = torch.tensor([1.0, -1.0, -1.0, -1.0])
self.register_buffer("metric", metric)
self.linear_in = Linear(
in_v_channels=v_channels,
out_v_channels=3 * self.hidden_v_channels * self.num_heads,
in_s_channels=s_channels,
out_s_channels=3 * self.hidden_s_channels * self.num_heads,
initialization="small",
)
self.linear_out = Linear(
in_v_channels=self.hidden_v_channels * self.num_heads,
out_v_channels=v_channels,
in_s_channels=self.hidden_s_channels * self.num_heads,
out_s_channels=s_channels,
initialization="small",
)
self.norm = RMSNorm()
if dropout_prob is not None:
self.dropout = Dropout(dropout_prob)
else:
self.dropout = None
def _pre_reshape(self, qkv_v, qkv_s):
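        # unpack the stacked q/k/v projections and move the head axis in front of the item axis:
        # vectors: (..., items, 3 * hidden_v * heads, 4) -> (3, ..., heads, items, hidden_v, 4)
        # scalars: (..., items, 3 * hidden_s * heads) -> (3, ..., heads, items, hidden_s)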
qkv_v = (
qkv_v.unflatten(-2, (3, self.hidden_v_channels, self.num_heads))
.movedim(-4, 0)
.movedim(-2, -4)
)
qkv_s = (
qkv_s.unflatten(-1, (3, self.hidden_s_channels, self.num_heads))
.movedim(-3, 0)
.movedim(-1, -3)
)
# normalize for stability (important)
qkv_v, qkv_s = self.norm(qkv_v, qkv_s)
q_v, k_v, v_v = qkv_v.unbind(0)
q_s, k_s, v_s = qkv_s.unbind(0)
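        # contract the queries with the Minkowski metric so that the plain dot product of the
        # flattened q and k reduces to Lorentz inner products on the vector part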
q_v_mod = q_v * self.metric
q = torch.cat([q_v_mod.flatten(start_dim=-2), q_s], dim=-1)
k = torch.cat([k_v.flatten(start_dim=-2), k_s], dim=-1)
v = torch.cat([v_v.flatten(start_dim=-2), v_s], dim=-1)
return q, k, v
def _post_reshape(self, out):
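        # out has shape (..., heads, items, hidden_v * 4 + hidden_s); split it back into vector
        # and scalar parts and fold the head axis into the channel axis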
h_v = out[..., : self.hidden_v_channels * 4].reshape(
*out.shape[:-1], self.hidden_v_channels, 4
)
h_s = out[..., self.hidden_v_channels * 4 :]
h_v = h_v.movedim(-3, -4).flatten(-3, -2)
h_s = h_s.movedim(-2, -3).flatten(-2, -1)
return h_v, h_s
def forward(self, vectors, scalars, **attn_kwargs):
"""
Parameters
----------
vectors : torch.Tensor
A tensor of shape (..., v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
A tensor of shape (..., s_channels) representing scalar features.
**attn_kwargs : dict
Additional keyword arguments for the attention function.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of the same shapes as the inputs, representing the attention outputs for the vectors and scalars.
"""
qkv_v, qkv_s = self.linear_in(vectors, scalars)
q, k, v = self._pre_reshape(qkv_v, qkv_s)
out = scaled_dot_product_attention(q, k, v, **attn_kwargs)
h_v, h_s = self._post_reshape(out)
out_v, out_s = self.linear_out(h_v, h_s)
if self.dropout is not None:
out_v, out_s = self.dropout(out_v, out_s)
return out_v, out_s
class MLP(nn.Module):
"""Multi-layer perceptron (MLP) for vector and scalar features."""
def __init__(
self,
v_channels: int,
s_channels: int,
nonlinearity: str = "gelu",
mlp_ratio: int = 2,
num_layers: int = 2,
dropout_prob: float | None = None,
):
super().__init__()
assert num_layers >= 2
layers = []
v_channels_list = [v_channels] + [mlp_ratio * v_channels] * (num_layers - 1) + [v_channels]
s_channels_list = [s_channels] + [mlp_ratio * s_channels] * (num_layers - 1) + [s_channels]
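        # e.g. v_channels=16, mlp_ratio=2, num_layers=3 gives v_channels_list = [16, 32, 32, 16]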
for i in range(num_layers - 1):
layers.append(
GatedLinearUnit(
in_v_channels=v_channels_list[i],
out_v_channels=v_channels_list[i + 1],
in_s_channels=s_channels_list[i],
out_s_channels=s_channels_list[i + 1],
nonlinearity=nonlinearity,
)
)
if dropout_prob is not None:
layers.append(Dropout(dropout_prob))
layers.append(
Linear(
in_v_channels=v_channels_list[-2],
out_v_channels=v_channels_list[-1],
in_s_channels=s_channels_list[-2],
out_s_channels=s_channels_list[-1],
)
)
self.layers = nn.ModuleList(layers)
def forward(self, vectors, scalars):
"""
Parameters
----------
vectors : torch.Tensor
A tensor of shape (..., v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
A tensor of shape (..., s_channels) representing scalar features.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of the same shapes as the inputs, representing the transformed vectors and scalars.
"""
v, s = vectors, scalars
for layer in self.layers:
v, s = layer(v, scalars=s)
return v, s
class LGATrSlimBlock(nn.Module):
"""A single block of the L-GATr-slim,
consisting of self-attention and MLP layers, pre-norm and residual connections."""
def __init__(
self,
v_channels: int,
s_channels: int,
num_heads: int,
nonlinearity: str = "gelu",
mlp_ratio: int = 2,
attn_ratio: int = 1,
num_layers_mlp: int = 2,
dropout_prob: float | None = None,
):
super().__init__()
self.norm = RMSNorm()
self.attention = SelfAttention(
v_channels=v_channels,
s_channels=s_channels,
num_heads=num_heads,
attn_ratio=attn_ratio,
dropout_prob=dropout_prob,
)
self.mlp = MLP(
v_channels=v_channels,
s_channels=s_channels,
nonlinearity=nonlinearity,
mlp_ratio=mlp_ratio,
num_layers=num_layers_mlp,
dropout_prob=dropout_prob,
)
def forward(self, vectors, scalars, **attn_kwargs):
"""
Parameters
----------
vectors : torch.Tensor
A tensor of shape (..., v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
A tensor of shape (..., s_channels) representing scalar features.
**attn_kwargs : dict
Additional keyword arguments for the attention function.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of the same shapes as the inputs, representing the block outputs for the vectors and scalars.
"""
h_v, h_s = self.norm(vectors, scalars)
h_v, h_s = self.attention(
h_v,
h_s,
**attn_kwargs,
)
outputs_v = vectors + h_v
outputs_s = scalars + h_s
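        # MLP sub-layer with pre-norm and residual connection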
h_v, h_s = self.norm(outputs_v, outputs_s)
h_v, h_s = self.mlp(h_v, h_s)
outputs_v = outputs_v + h_v
outputs_s = outputs_s + h_s
return outputs_v, outputs_s
class LGATrSlim(nn.Module):
"""L-GATr-slim network."""
def __init__(
self,
in_v_channels: int,
out_v_channels: int,
hidden_v_channels: int,
in_s_channels: int,
out_s_channels: int,
hidden_s_channels: int,
num_blocks: int,
num_heads: int,
nonlinearity: str = "gelu",
mlp_ratio: int = 2,
attn_ratio: int = 1,
num_layers_mlp: int = 2,
dropout_prob: float | None = None,
checkpoint_blocks: bool = False,
compile: bool = False,
):
"""
Parameters
----------
in_v_channels : int
Number of input vector channels.
out_v_channels : int
Number of output vector channels.
hidden_v_channels : int
Number of hidden vector channels.
in_s_channels : int
Number of input scalar channels.
out_s_channels : int
Number of output scalar channels.
hidden_s_channels : int
Number of hidden scalar channels.
num_blocks : int
Number of Lorentz Transformer blocks.
num_heads : int
Number of attention heads.
nonlinearity : str, optional
Nonlinearity type for MLP layers, by default "gelu".
mlp_ratio : int, optional
Expansion ratio for MLP hidden layers, by default 2.
attn_ratio : int, optional
Expansion ratio for attention hidden layers, by default 1.
num_layers_mlp : int, optional
Number of layers in MLP, by default 2.
dropout_prob : float | None, optional
Dropout probability, by default None.
checkpoint_blocks : bool, optional
Whether to use gradient checkpointing for blocks, by default False.
compile : bool, optional
Whether to compile the model with torch.compile, by default False.
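
        Examples
        --------
        A minimal usage sketch; channel counts, shapes, and hyperparameters are illustrative,
        and a plain dense attention call without extra attention kwargs is assumed.

        >>> model = LGATrSlim(
        ...     in_v_channels=1, out_v_channels=1, hidden_v_channels=8,
        ...     in_s_channels=4, out_s_channels=1, hidden_s_channels=16,
        ...     num_blocks=2, num_heads=2,
        ... )
        >>> v = torch.randn(32, 20, 1, 4)  # (batch, items, in_v_channels, 4)
        >>> s = torch.randn(32, 20, 4)  # (batch, items, in_s_channels)
        >>> out_v, out_s = model(v, s)
        >>> out_v.shape, out_s.shape
        (torch.Size([32, 20, 1, 4]), torch.Size([32, 20, 1]))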
"""
super().__init__()
self.linear_in = Linear(
in_v_channels=in_v_channels,
in_s_channels=in_s_channels,
out_v_channels=hidden_v_channels,
out_s_channels=hidden_s_channels,
)
self.blocks = nn.ModuleList(
[
LGATrSlimBlock(
v_channels=hidden_v_channels,
s_channels=hidden_s_channels,
num_heads=num_heads,
nonlinearity=nonlinearity,
mlp_ratio=mlp_ratio,
attn_ratio=attn_ratio,
num_layers_mlp=num_layers_mlp,
dropout_prob=dropout_prob,
)
for _ in range(num_blocks)
]
)
self.linear_out = Linear(
in_v_channels=hidden_v_channels,
in_s_channels=hidden_s_channels,
out_v_channels=out_v_channels,
out_s_channels=out_s_channels,
)
self._checkpoint_blocks = checkpoint_blocks
self.compile = compile
if compile:
# ugly hack to make torch.compile convenient for users
# the clean solution is model = torch.compile(model, **kwargs) outside of the constructor
# note that we need fullgraph=False because of the torch.compiler.disable for attention
self.__class__ = torch.compile(self.__class__, dynamic=True, mode="default")
def forward(self, vectors, scalars, **attn_kwargs):
"""
Parameters
----------
vectors : torch.Tensor
            A tensor of shape (..., in_v_channels, 4) representing Lorentz vectors.
scalars : torch.Tensor
            A tensor of shape (..., in_s_channels) representing scalar features.
**attn_kwargs : dict
Additional keyword arguments for the attention function.
Returns
-------
torch.Tensor, torch.Tensor
            Tensors of shape (..., out_v_channels, 4) and (..., out_s_channels) representing
            the network outputs.
"""
h_v, h_s = self.linear_in(vectors, scalars)
for block in self.blocks:
if self._checkpoint_blocks:
h_v, h_s = checkpoint(block, h_v, h_s, use_reentrant=False, **attn_kwargs)
else:
h_v, h_s = block(h_v, h_s, **attn_kwargs)
outputs_v, outputs_s = self.linear_out(h_v, h_s)
return outputs_v, outputs_s