lgatr.nets.lgatr_slim.GatedLinearUnit
- class lgatr.nets.lgatr_slim.GatedLinearUnit(in_v_channels, out_v_channels, in_s_channels, out_s_channels, nonlinearity='gelu')
Bases: Module
Gated linear unit (GLU) for vector and scalar features.
Scalar gates are computed from scalar features, while vector gates are computed from inner products of vector features.
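The gating idea can be illustrated with a minimal pure-Python sketch for a single sample. It assumes that linear layers have already produced the value and gate features. It also assumes the Minkowski metric with signature (+, -, -, -), and that each vector gate is the nonlinearity applied to the invariant norm <g, g> of a gate vector. These choices, and the names `minkowski_inner`, `gelu`, `gate_features`, and `boost_z`, are illustrative assumptions rather than the library's code. Each vector gate depends on vectors only through a Lorentz-invariant inner product, so scaling a value vector by its gate keeps the output equivariant. The last lines check this numerically with a boost along z.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]  # one Lorentz 4-vector (E, px, py, pz)

# Assumed metric signature (+, -, -, -); the library's convention may differ.
METRIC = (1.0, -1.0, -1.0, -1.0)


def minkowski_inner(a: Sequence[float], b: Sequence[float]) -> float:
    """Lorentz-invariant inner product <a, b> = a0*b0 - a1*b1 - a2*b2 - a3*b3."""
    return sum(m * x * y for m, x, y in zip(METRIC, a, b))


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gate_features(
    vec_values: Sequence[Vector],
    vec_gates: Sequence[Vector],
    scal_values: Sequence[float],
    scal_gates: Sequence[float],
    nonlinearity: Callable[[float], float] = gelu,
) -> Tuple[List[Vector], List[float]]:
    """Hypothetical gating step: one multiplicative gate per output channel."""
    # Vector channels: the gate is a scalar built from the invariant <g, g>,
    # so multiplying the value 4-vector by it preserves Lorentz equivariance.
    out_vectors = [
        [nonlinearity(minkowski_inner(g, g)) * x for x in v]
        for v, g in zip(vec_values, vec_gates)
    ]
    # Scalar channels: standard GLU-style gating, nonlinearity(gate) * value.
    out_scalars = [nonlinearity(g) * s for s, g in zip(scal_values, scal_gates)]
    return out_vectors, out_scalars


def boost_z(v: Sequence[float], rapidity: float) -> Vector:
    """Lorentz boost along z; leaves minkowski_inner unchanged."""
    ch, sh = math.cosh(rapidity), math.sinh(rapidity)
    e, px, py, pz = v
    return [ch * e + sh * pz, px, py, sh * e + ch * pz]


# Equivariance check: gating then boosting should equal boosting then gating.
values = [[2.0, 0.5, -1.0, 1.0], [1.0, 0.0, 0.3, -0.2]]
gates = [[1.5, 0.2, 0.1, -0.4], [0.7, 0.1, 0.0, 0.2]]
out_v, out_s = gate_features(values, gates, [3.0, -1.0], [0.0, 2.0])
boosted_v, _ = gate_features(
    [boost_z(v, 0.7) for v in values],
    [boost_z(g, 0.7) for g in gates],
    [3.0, -1.0],
    [0.0, 2.0],
)
max_error = max(
    abs(x - y)
    for a, b in zip(boosted_v, out_v)
    for x, y in zip(a, boost_z(b, 0.7))
)
```

A small `max_error` (floating-point round-off only) is what the sketch is meant to show. If the gates used raw vector components instead of inner products, a boost would change them and the check would fail.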
- forward(vectors, scalars)
- Parameters:
vectors (torch.Tensor) – A tensor of shape (…, in_v_channels, 4) representing Lorentz vectors.
scalars (torch.Tensor) – A tensor of shape (…, in_s_channels) representing scalar features.
- Returns:
Gated vectors of shape (…, out_v_channels, 4) and gated scalars of shape (…, out_s_channels).
- Return type:
torch.Tensor, torch.Tensor
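The following hypothetical single-sample stand-in makes the channel bookkeeping concrete. Its constructor mirrors the documented signature. It assumes the layer first applies a Lorentz-equivariant linear map that produces twice the requested output channels, splits the result into values and gates, and then gates them as in a GLU. The class name `GatedLinearUnitSketch`, the random weights, the absence of mixing between vector and scalar channels, and the (+, -, -, -) metric are all assumptions, not the library's implementation. Plain nested lists stand in for tensors, and batch dimensions are omitted.

```python
import math
import random
from typing import List, Sequence, Tuple

METRIC = (1.0, -1.0, -1.0, -1.0)  # assumed (+, -, -, -) signature


def minkowski_inner(a: Sequence[float], b: Sequence[float]) -> float:
    """Lorentz-invariant inner product of two 4-vectors."""
    return sum(m * x * y for m, x, y in zip(METRIC, a, b))


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


class GatedLinearUnitSketch:
    """Hypothetical stand-in; only the constructor arguments follow the docs."""

    def __init__(self, in_v_channels: int, out_v_channels: int,
                 in_s_channels: int, out_s_channels: int,
                 nonlinearity: str = "gelu") -> None:
        if nonlinearity != "gelu":
            raise ValueError("this sketch only implements 'gelu'")
        rng = random.Random(0)  # fixed seed so the sketch is deterministic
        # Vector weights are plain scalars that mix whole 4-vectors, which keeps
        # the map Lorentz-equivariant. 2 * out channels = values + gates.
        self.w_v = [[rng.gauss(0.0, 1.0) for _ in range(in_v_channels)]
                    for _ in range(2 * out_v_channels)]
        self.w_s = [[rng.gauss(0.0, 1.0) for _ in range(in_s_channels)]
                    for _ in range(2 * out_s_channels)]
        self.in_v, self.in_s = in_v_channels, in_s_channels
        self.out_v, self.out_s = out_v_channels, out_s_channels

    def forward(self, vectors: Sequence[Sequence[float]],
                scalars: Sequence[float]) -> Tuple[List[List[float]], List[float]]:
        if len(vectors) != self.in_v or len(scalars) != self.in_s:
            raise ValueError("expected in_v_channels vectors and in_s_channels scalars")
        # Linear step: each output vector channel is a weighted sum of input 4-vectors.
        mixed_v = [[sum(w * v[k] for w, v in zip(row, vectors)) for k in range(4)]
                   for row in self.w_v]
        mixed_s = [sum(w * s for w, s in zip(row, scalars)) for row in self.w_s]
        # Split into values (first half) and gates (second half).
        v_vals, v_gates = mixed_v[:self.out_v], mixed_v[self.out_v:]
        s_vals, s_gates = mixed_s[:self.out_s], mixed_s[self.out_s:]
        # Gating step: invariant-based gates for vectors, elementwise for scalars.
        out_vectors = [[gelu(minkowski_inner(g, g)) * x for x in v]
                       for v, g in zip(v_vals, v_gates)]
        out_scalars = [gelu(g) * s for s, g in zip(s_vals, s_gates)]
        return out_vectors, out_scalars


glu = GatedLinearUnitSketch(in_v_channels=3, out_v_channels=2,
                            in_s_channels=5, out_s_channels=4)
vectors = [[1.0, 0.1, 0.2, 0.3], [2.0, -0.5, 0.0, 1.0], [0.5, 0.0, 0.1, 0.0]]
scalars = [0.1, -0.2, 0.3, 0.0, 1.0]
out_vectors, out_scalars = glu.forward(vectors, scalars)
```

Under these assumptions, 3 input vectors and 5 input scalars become 2 output vectors (each with 4 components) and 4 output scalars. With the real module, the inputs would instead be torch tensors carrying leading batch dimensions.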