lgatr.layers.lgatr_block.LGATrBlock
- class lgatr.layers.lgatr_block.LGATrBlock(mv_channels, s_channels, attention, mlp, dropout_prob=None)[source]
Bases: Module
L-GATr encoder block.
Inputs are first processed by a block consisting of LayerNorm, multi-head geometric self-attention, and a residual connection. The result is then processed by a second block consisting of another LayerNorm, an item-wise two-layer geometric MLP with GELU activations, and another residual connection.
- Parameters:
mv_channels (int) – Number of input and output multivector channels
s_channels (int) – Number of input and output scalar channels
attention (SelfAttentionConfig) – Attention configuration
mlp (MLPConfig) – MLP configuration
dropout_prob (float or None) – Dropout probability
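The pre-norm residual wiring described above follows the standard transformer pattern. The sketch below illustrates that wiring with plain PyTorch layers as stand-ins: `nn.MultiheadAttention` replaces the geometric self-attention and an ordinary two-layer GELU MLP replaces the geometric MLP; the equivariant L-GATr primitives themselves are not reproduced here.

```python
import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    """Illustrates LGATrBlock's residual wiring with standard torch stand-ins.

    Real LGATrBlock uses equivariant LayerNorm, geometric attention, and a
    geometric MLP; this sketch only mirrors the data flow.
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        # Item-wise two-layer MLP with GELU activation
        self.mlp = nn.Sequential(
            nn.Linear(channels, 2 * channels),
            nn.GELU(),
            nn.Linear(2 * channels, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Block 1: LayerNorm -> self-attention -> residual connection
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # Block 2: LayerNorm -> two-layer MLP -> residual connection
        x = x + self.mlp(self.norm2(x))
        return x

block = PreNormBlockSketch(channels=16)
out = block(torch.randn(2, 8, 16))  # (batch, items, channels); shape is preserved
```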
- forward(multivectors, scalars, additional_qk_features_mv=None, additional_qk_features_s=None, **attn_kwargs)[source]
Forward pass of the transformer encoder block.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, items, mv_channels, 16).
scalars (torch.Tensor) – Input scalars with shape (…, items, s_channels).
additional_qk_features_mv (None or torch.Tensor) – Additional multivector Q/K features with shape (…, items, add_qk_mv_channels, 16).
additional_qk_features_s (None or torch.Tensor) – Additional scalar Q/K features with shape (…, items, add_qk_s_channels).
**attn_kwargs – Optional keyword arguments passed to attention.
- Return type:
Tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, mv_channels, 16).
output_scalars (torch.Tensor) – Output scalars with shape (…, items, s_channels).