lgatr.nets.lgatr_slim.LGATrSlim
- class lgatr.nets.lgatr_slim.LGATrSlim(in_v_channels, out_v_channels, hidden_v_channels, in_s_channels, out_s_channels, hidden_s_channels, num_blocks, num_heads, nonlinearity='gelu', mlp_ratio=2, attn_ratio=1, num_layers_mlp=2, dropout_prob=None, checkpoint_blocks=False, compile=False)
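Bases: torch.nn.Module

L-GATr-slim network: a slim variant of the Lorentz-equivariant Geometric Algebra Transformer (L-GATr) that operates directly on Lorentz vectors and scalars.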
- forward(vectors, scalars, **attn_kwargs)
- Parameters:
vectors (torch.Tensor) – A tensor of shape (…, in_v_channels, 4) representing Lorentz vectors.
scalars (torch.Tensor) – A tensor of shape (…, in_s_channels) representing scalar features.
**attn_kwargs (dict) – Additional keyword arguments for the attention function.
- Returns:
The output vectors, of shape (…, out_v_channels, 4), and the output scalars, of shape (…, out_s_channels). The leading dimensions match those of the inputs.
- Return type:
torch.Tensor, torch.Tensor