lgatr.layers.conditional_lgatr_block.ConditionalLGATrBlock
- class lgatr.layers.conditional_lgatr_block.ConditionalLGATrBlock(mv_channels, s_channels, condition_mv_channels, condition_s_channels, attention, crossattention, mlp, dropout_prob=None)
Bases: torch.nn.Module
L-GATr decoder block.
Inputs are first processed by a block consisting of LayerNorm, multi-head geometric self-attention, and a residual connection. The condition is then incorporated through geometric cross-attention, wrapped in the same LayerNorm and residual connection as the self-attention sub-block. Finally, the data is processed by a block consisting of another LayerNorm, an item-wise two-layer geometric MLP with GeLU activations, and another residual connection.
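The data flow of this pre-norm decoder block (self-attention, then cross-attention to the condition, then MLP, each with its own LayerNorm and residual connection) can be sketched with ordinary, non-geometric NumPy stand-ins. This is a minimal illustration only. It assumes single-head attention, a LayerNorm without learnable parameters, and the hypothetical helpers `layer_norm`, `attention`, `gelu`, and `conditional_block`. It does not implement L-GATr's equivariant multivector operations and is not the library's code.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Plain LayerNorm over the feature axis (no learnable scale/shift).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q_in, kv_in, wq, wk, wv):
    # Single-head scaled dot-product attention:
    # queries come from q_in, keys and values from kv_in.
    q, k, v = q_in @ wq, kv_in @ wk, kv_in @ wv
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return weights @ v


def gelu(x):
    # tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def conditional_block(x, cond, p):
    # 1) LayerNorm -> self-attention -> residual connection
    h = layer_norm(x)
    x = x + attention(h, h, p["sa_q"], p["sa_k"], p["sa_v"])
    # 2) LayerNorm -> cross-attention to the condition -> residual connection
    #    (the condition only supplies keys and values; it is not normalized here)
    h = layer_norm(x)
    x = x + attention(h, cond, p["ca_q"], p["ca_k"], p["ca_v"])
    # 3) LayerNorm -> item-wise two-layer MLP with GELU -> residual connection
    h = layer_norm(x)
    x = x + gelu(h @ p["mlp_1"]) @ p["mlp_2"]
    return x


rng = np.random.default_rng(0)
d, d_cond, d_k, hidden = 8, 5, 4, 16
shapes = {
    "sa_q": (d, d_k), "sa_k": (d, d_k), "sa_v": (d, d),
    "ca_q": (d, d_k), "ca_k": (d_cond, d_k), "ca_v": (d_cond, d),
    "mlp_1": (d, hidden), "mlp_2": (hidden, d),
}
# Random weights, scaled by 1/sqrt(fan_in) to keep activations moderate.
p = {name: rng.normal(size=shape) / np.sqrt(shape[0]) for name, shape in shapes.items()}

x = rng.normal(size=(10, d))          # 10 items with d features each
cond = rng.normal(size=(3, d_cond))   # 3 condition items with d_cond features each
out = conditional_block(x, cond, p)
print(out.shape)  # (10, 8)
```

Because the condition only enters as keys and values, it may have a different feature size and a different number of items than the input, while the output always keeps the input's shape.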
- Parameters:
mv_channels (int) – Number of input and output multivector channels
s_channels (int) – Number of input and output scalar channels
condition_mv_channels (int) – Number of condition multivector channels
condition_s_channels (int) – Number of condition scalar channels
attention (SelfAttentionConfig) – Attention configuration
crossattention (CrossAttentionConfig) – Cross-attention configuration
mlp (MLPConfig) – MLP configuration
dropout_prob (float or None) – Dropout probability
- forward(multivectors, multivectors_condition, scalars=None, scalars_condition=None, attn_kwargs={}, crossattn_kwargs={})
Forward pass of the transformer decoder block.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, items, mv_channels, 16).
multivectors_condition (torch.Tensor) – Condition multivectors with shape (…, items_condition, condition_mv_channels, 16).
scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, s_channels).
scalars_condition (None or torch.Tensor) – Optional condition scalars with shape (…, items_condition, condition_s_channels).
attn_kwargs (dict) – Optional keyword arguments for the self-attention, such as an attention mask.
crossattn_kwargs (dict) – Optional keyword arguments for the cross-attention, such as an attention mask for the condition.
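The shape conventions for `forward` can be made concrete with a small pure-Python check that works on shape tuples rather than tensors. `check_forward_shapes` is a hypothetical helper, not part of lgatr. It assumes that the condition tensors carry `condition_mv_channels` and `condition_s_channels` channels, that the number of condition items may differ from the number of input items (as is usual for cross-attention), and, as a simplification of this sketch, that the leading batch dimensions of input and condition match exactly.

```python
def check_forward_shapes(mv_shape, cond_mv_shape, s_shape=None, cond_s_shape=None,
                         *, mv_channels, s_channels,
                         condition_mv_channels, condition_s_channels):
    """Hypothetical shape check for the inputs of ConditionalLGATrBlock.forward.

    Returns the expected shapes of (outputs_mv, output_scalars).
    """
    *batch, items, c_mv, comps = mv_shape
    assert comps == 16, "multivectors must have 16 components"
    assert c_mv == mv_channels, "wrong number of multivector channels"

    *cond_batch, cond_items, c_cond_mv, cond_comps = cond_mv_shape
    assert cond_comps == 16, "condition multivectors must have 16 components"
    assert c_cond_mv == condition_mv_channels, "wrong number of condition mv channels"
    # Assumption of this sketch: leading batch dimensions match exactly
    # (broadcasting is not considered).
    assert cond_batch == batch, "batch dimensions of input and condition differ"

    # Scalars are optional; check them only when given.
    if s_shape is not None:
        assert tuple(s_shape) == (*batch, items, s_channels)
    if cond_s_shape is not None:
        assert tuple(cond_s_shape) == (*cond_batch, cond_items, condition_s_channels)

    # Outputs keep the item and channel layout of the (non-condition) inputs;
    # cond_items is free to differ from items.
    out_s = (*batch, items, s_channels)
    return tuple(mv_shape), out_s


mv_out, s_out = check_forward_shapes(
    (2, 10, 4, 16), (2, 3, 6, 16), (2, 10, 8), (2, 3, 5),
    mv_channels=4, s_channels=8, condition_mv_channels=6, condition_s_channels=5,
)
print(mv_out, s_out)  # (2, 10, 4, 16) (2, 10, 8)
```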
- Return type:
Tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, mv_channels, 16).
output_scalars (torch.Tensor) – Output scalars with shape (…, items, s_channels).