lgatr.nets.conditional_lgatr_slim.ConditionalLGATrSlimBlock
- class lgatr.nets.conditional_lgatr_slim.ConditionalLGATrSlimBlock(v_channels, condition_v_channels, s_channels, condition_s_channels, num_heads, nonlinearity='gelu', mlp_ratio=2, attn_ratio=1, num_layers_mlp=2, dropout_prob=None)
Bases: Module
A single block of the conditional L-GATr-slim, consisting of self-attention, cross-attention and MLP layers, each with pre-norm and a residual connection.
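The pre-norm residual layout of such a block can be sketched in a few lines. This is a minimal, framework-free sketch: `layer_norm`, `add`, `prenorm_conditional_block` and the `self_attn`, `cross_attn` and `mlp` callables are hypothetical stand-ins, not part of the lgatr API. Features are flat Python lists, so it shows only the ordering of normalization, sublayer and residual sum. It does not show the Lorentz-equivariant operations the real block performs.

```python
from typing import Callable, List

Features = List[float]


def layer_norm(x: Features, eps: float = 1e-6) -> Features:
    # Hypothetical stand-in normalization: a plain layer norm over one feature list.
    # The real block normalizes Lorentz vectors and scalars in its own way.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / (var + eps) ** 0.5 for v in x]


def add(a: Features, b: Features) -> Features:
    # Element-wise sum used for the residual connections.
    return [u + v for u, v in zip(a, b)]


def prenorm_conditional_block(
    x: Features,
    condition: Features,
    self_attn: Callable[[Features], Features],
    cross_attn: Callable[[Features, Features], Features],
    mlp: Callable[[Features], Features],
) -> Features:
    # 1) Self-attention on the normalized input, added back to the input.
    x = add(x, self_attn(layer_norm(x)))
    # 2) Cross-attention: the normalized input attends to the condition.
    #    Passing the condition un-normalized is an assumption of this sketch.
    x = add(x, cross_attn(layer_norm(x), condition))
    # 3) MLP on the normalized result, again with a residual connection.
    x = add(x, mlp(layer_norm(x)))
    return x
```

If every sublayer returns zeros, the block reduces to the identity map, because each sublayer only adds a correction to its input.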
- forward(vectors, vectors_condition, scalars, scalars_condition, attn_kwargs=None, crossattn_kwargs=None)
- Parameters:
vectors (torch.Tensor) – A tensor of shape (…, v_channels, 4) representing Lorentz vectors.
vectors_condition (torch.Tensor) – A tensor of shape (…, condition_v_channels, 4) representing a Lorentz vector condition included in cross-attention.
scalars (torch.Tensor) – A tensor of shape (…, s_channels) representing scalar features.
scalars_condition (torch.Tensor) – A tensor of shape (…, condition_s_channels) representing a scalar condition included in cross-attention.
attn_kwargs (dict, optional) – Additional keyword arguments for the attention function.
crossattn_kwargs (dict, optional) – Additional keyword arguments for the cross-attention function.
- Returns:
Two tensors with the same shapes as vectors and scalars, representing the updated vectors and scalars after the block.
- Return type:
torch.Tensor, torch.Tensor
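The documented shapes can be checked before calling forward. The helper below, `check_block_shapes`, is hypothetical and not part of lgatr. It validates the trailing axes listed under Parameters and returns the output shapes implied by the Returns section. Two of its checks are assumptions not stated above: that vectors and scalars share their leading (batch/token) axes, and that the two condition tensors do likewise.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_block_shapes(
    vectors: Shape,
    vectors_condition: Shape,
    scalars: Shape,
    scalars_condition: Shape,
    v_channels: int,
    condition_v_channels: int,
    s_channels: int,
    condition_s_channels: int,
) -> Tuple[Shape, Shape]:
    """Check shapes against the documented layout; return expected output shapes."""
    # Lorentz vectors carry a trailing axis of 4 components.
    if tuple(vectors[-2:]) != (v_channels, 4):
        raise ValueError(f"vectors: expected (..., {v_channels}, 4), got {vectors}")
    if tuple(vectors_condition[-2:]) != (condition_v_channels, 4):
        raise ValueError(
            f"vectors_condition: expected (..., {condition_v_channels}, 4), "
            f"got {vectors_condition}"
        )
    if tuple(scalars[-1:]) != (s_channels,):
        raise ValueError(f"scalars: expected (..., {s_channels}), got {scalars}")
    if tuple(scalars_condition[-1:]) != (condition_s_channels,):
        raise ValueError(
            f"scalars_condition: expected (..., {condition_s_channels}), "
            f"got {scalars_condition}"
        )
    # Assumption: vectors and scalars describe the same items, so their leading
    # axes agree; likewise for the two condition tensors.
    if tuple(vectors[:-2]) != tuple(scalars[:-1]):
        raise ValueError("vectors and scalars have mismatched leading axes")
    if tuple(vectors_condition[:-2]) != tuple(scalars_condition[:-1]):
        raise ValueError("condition tensors have mismatched leading axes")
    # Per the Returns section, the outputs keep the input shapes.
    return tuple(vectors), tuple(scalars)


# Example: 10 input tokens attending to 5 condition tokens (token counts are illustrative).
out_shapes = check_block_shapes(
    (2, 10, 8, 4), (2, 5, 3, 4), (2, 10, 16), (2, 5, 6),
    v_channels=8, condition_v_channels=3, s_channels=16, condition_s_channels=6,
)
```

Passing `tensor.shape` directly also works, since `torch.Size` is a subclass of `tuple`. The example assumes the condition may have a different number of tokens than the input, as is common for cross-attention.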