lgatr.nets.lgatr_slim.Linear

class lgatr.nets.lgatr_slim.Linear(in_v_channels, out_v_channels, in_s_channels, out_s_channels, bias=True, initialization='default')

Bases: Module

Linear operations for vector and scalar features.

Supports optional mixing between vector and scalar features to improve expressivity.
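The following is a minimal PyTorch sketch of the channel-mixing part of such a layer. It is not the lgatr source. It assumes that vector channels are mixed by one weight matrix shared across all four Lorentz components and that only the scalar branch carries a bias. The class name `SlimLinearSketch` is hypothetical, and the optional vector/scalar mixing of the real class is deliberately left out.

```python
import torch
from torch import nn


class SlimLinearSketch(nn.Module):
    """Hypothetical sketch, not the lgatr implementation.

    Mixes channels of Lorentz vectors and scalars separately. The same
    weight is applied to all four components of each vector. The optional
    vector/scalar mixing of lgatr's Linear is not modeled here.
    """

    def __init__(self, in_v_channels, out_v_channels, in_s_channels, out_s_channels, bias=True):
        super().__init__()
        # Assumption: no bias on vectors, since adding a fixed 4-vector
        # would not transform along with the inputs.
        self.linear_v = nn.Linear(in_v_channels, out_v_channels, bias=False)
        self.linear_s = nn.Linear(in_s_channels, out_s_channels, bias=bias)

    def forward(self, vectors, scalars):
        # vectors: (..., in_v_channels, 4). Move the channel axis last,
        # mix channels, then move it back to get (..., out_v_channels, 4).
        v = self.linear_v(vectors.transpose(-1, -2)).transpose(-1, -2)
        # scalars: (..., in_s_channels) -> (..., out_s_channels)
        s = self.linear_s(scalars)
        return v, s


layer = SlimLinearSketch(in_v_channels=3, out_v_channels=5, in_s_channels=2, out_s_channels=7)
vectors = torch.randn(10, 3, 4)
scalars = torch.randn(10, 2)
v_out, s_out = layer(vectors, scalars)
print(v_out.shape, s_out.shape)  # torch.Size([10, 5, 4]) torch.Size([10, 7])
```

In this sketch the four Lorentz components are never mixed with each other. Only the channel dimension changes, which is what lets the outputs keep the interpretation of Lorentz vectors.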

forward(vectors, scalars)
Parameters:
  • vectors (torch.Tensor) – A tensor of shape (…, in_v_channels, 4) representing Lorentz vectors.

  • scalars (torch.Tensor) – A tensor of shape (…, in_s_channels) representing scalar features.

Returns:

The transformed vectors, of shape (…, out_v_channels, 4), and the transformed scalars, of shape (…, out_s_channels).

Return type:

torch.Tensor, torch.Tensor
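If the vector weights of a layer like this act only on the channel axis (an assumption about the implementation), they commute with any Lorentz transformation acting on the last axis, so the returned vectors still transform as Lorentz vectors. The check below illustrates this for a boost along x. The helper names `boost_x`, `mix_channels`, and `apply_lorentz` are hypothetical, and the component order (t, x, y, z) is assumed.

```python
import math

import torch

torch.manual_seed(0)


def boost_x(rapidity: float) -> torch.Tensor:
    """Lorentz boost along x, assuming component order (t, x, y, z)."""
    c, s = math.cosh(rapidity), math.sinh(rapidity)
    lam = torch.eye(4, dtype=torch.float64)
    lam[0, 0] = lam[1, 1] = c
    lam[0, 1] = lam[1, 0] = s
    return lam


def mix_channels(weight: torch.Tensor, vectors: torch.Tensor) -> torch.Tensor:
    # weight: (out_v, in_v); vectors: (..., in_v, 4) -> (..., out_v, 4)
    return torch.einsum("oi,...ic->...oc", weight, vectors)


def apply_lorentz(lam: torch.Tensor, vectors: torch.Tensor) -> torch.Tensor:
    # Acts on the last (4-component) axis: v'^mu = Lambda^mu_nu v^nu
    return vectors @ lam.T


weight = torch.randn(5, 3, dtype=torch.float64)
vectors = torch.randn(10, 3, 4, dtype=torch.float64)
lam = boost_x(0.7)

# Transform-then-mix should equal mix-then-transform.
lhs = mix_channels(weight, apply_lorentz(lam, vectors))
rhs = apply_lorentz(lam, mix_channels(weight, vectors))
print(torch.allclose(lhs, rhs))  # True
```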

reset_parameters(initialization, additional_factor=1.0)