lgatr.layers.conditional_lgatr_block.ConditionalLGATrBlock

class lgatr.layers.conditional_lgatr_block.ConditionalLGATrBlock(mv_channels, s_channels, condition_mv_channels, condition_s_channels, attention, crossattention, mlp, dropout_prob=None)

Bases: Module

L-GATr decoder block.

Inputs are first processed by a block consisting of a LayerNorm, multi-head geometric self-attention, and a residual connection. The condition is then incorporated through cross-attention, wrapped in the same LayerNorm and residual connection as the self-attention part. Finally, the data is processed by a block consisting of another LayerNorm, an item-wise two-layer geometric MLP with GELU activations, and another residual connection.
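The wiring described above (LayerNorm, then a residual self-attention step, a residual cross-attention step, and a residual MLP step) can be sketched with standard PyTorch layers. This is a minimal, non-equivariant analogue under stated assumptions: plain feature vectors replace L-GATr's multivector and scalar channels, and torch.nn.LayerNorm, torch.nn.MultiheadAttention, and a GELU MLP replace the geometric layers. The class ConditionalBlockSketch and its arguments are hypothetical and not part of the lgatr API; only the order of operations mirrors the block documented here.

```python
import torch
from torch import nn


class ConditionalBlockSketch(nn.Module):
    """Hypothetical, NON-equivariant analogue of a conditional (decoder) block.

    Assumption: ordinary feature vectors stand in for multivector/scalar channels,
    and standard torch.nn layers stand in for the geometric LayerNorm, attention
    and MLP. Only the residual wiring follows the documented block.
    """

    def __init__(self, channels, condition_channels, num_heads=2, hidden=64, dropout_prob=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.self_attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        # Keys/values come from the condition, which may have a different width
        self.cross_attn = nn.MultiheadAttention(
            channels, num_heads, kdim=condition_channels, vdim=condition_channels, batch_first=True
        )
        self.norm3 = nn.LayerNorm(channels)
        # Item-wise two-layer MLP with GELU activation
        self.mlp = nn.Sequential(nn.Linear(channels, hidden), nn.GELU(), nn.Linear(hidden, channels))
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity()

    def forward(self, x, condition, attn_mask=None, crossattn_mask=None):
        # 1) LayerNorm -> self-attention -> residual connection
        h = self.norm1(x)
        h, _ = self.self_attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + self.dropout(h)
        # 2) LayerNorm -> cross-attention to the condition -> residual connection
        h = self.norm2(x)
        h, _ = self.cross_attn(h, condition, condition, attn_mask=crossattn_mask, need_weights=False)
        x = x + self.dropout(h)
        # 3) LayerNorm -> item-wise MLP -> residual connection
        h = self.mlp(self.norm3(x))
        return x + self.dropout(h)


block = ConditionalBlockSketch(channels=8, condition_channels=6)
x = torch.randn(3, 10, 8)     # (batch, items, channels)
cond = torch.randn(3, 5, 6)   # (batch, condition_items, condition_channels)
out = block(x, cond)
print(out.shape)  # torch.Size([3, 10, 8])
```

As in the documented block, the output keeps the shape of the non-condition input, while the condition may have a different number of items and channels.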

Parameters:
  • mv_channels (int) – Number of input and output multivector channels

  • s_channels (int) – Number of input and output scalar channels

  • condition_mv_channels (int) – Number of condition multivector channels

  • condition_s_channels (int) – Number of condition scalar channels

  • attention (SelfAttentionConfig) – Attention configuration

  • crossattention (CrossAttentionConfig) – Cross-attention configuration

  • mlp (MLPConfig) – MLP configuration

  • dropout_prob (float or None) – Dropout probability

forward(multivectors, multivectors_condition, scalars=None, scalars_condition=None, attn_kwargs={}, crossattn_kwargs={})

Forward pass of the transformer decoder block.

Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, items, mv_channels, 16).

  • scalars (torch.Tensor) – Input scalars with shape (…, items, s_channels).

  • multivectors_condition (torch.Tensor) – Input condition multivectors with shape (…, condition_items, condition_mv_channels, 16).

  • scalars_condition (torch.Tensor) – Input condition scalars with shape (…, condition_items, condition_s_channels).

  • attn_kwargs (dict) – Optional keyword arguments passed to the self-attention layer, for example an attention mask.

  • crossattn_kwargs (dict) – Optional keyword arguments passed to the cross-attention layer, for example an attention mask for the condition.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, mv_channels, 16).

  • output_scalars (torch.Tensor) – Output scalars with shape (…, items, s_channels).
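To make the shape contract of forward concrete, the helper below checks input shapes against the constructor's channel counts and returns the expected output shapes. It is a hypothetical illustration, not part of lgatr. It encodes only the documented layout: multivectors carry a trailing axis of 16 components, scalars share the item axis with their multivectors, and outputs keep the shapes of the non-condition inputs. It assumes, as is usual for cross-attention, that the number of condition items may differ from the number of input items.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def expected_output_shapes(
    mv_shape: Shape,
    mv_condition_shape: Shape,
    mv_channels: int,
    condition_mv_channels: int,
    s_channels: int,
    condition_s_channels: int,
    s_shape: Optional[Shape] = None,
    s_condition_shape: Optional[Shape] = None,
) -> Tuple[Shape, Optional[Shape]]:
    """Hypothetical helper: validate forward() input shapes, return output shapes."""
    # Multivectors: (..., items, channels, 16), 16 components per multivector
    if mv_shape[-2:] != (mv_channels, 16):
        raise ValueError(f"multivectors must end in ({mv_channels}, 16), got {mv_shape}")
    if mv_condition_shape[-2:] != (condition_mv_channels, 16):
        raise ValueError(
            f"condition multivectors must end in ({condition_mv_channels}, 16), got {mv_condition_shape}"
        )
    # Scalars: (..., items, channels), sharing the item axis with their multivectors
    if s_shape is not None and s_shape != mv_shape[:-2] + (s_channels,):
        raise ValueError(f"scalars must have shape {mv_shape[:-2] + (s_channels,)}, got {s_shape}")
    if s_condition_shape is not None and s_condition_shape != mv_condition_shape[:-2] + (condition_s_channels,):
        raise ValueError(
            f"condition scalars must have shape {mv_condition_shape[:-2] + (condition_s_channels,)}, "
            f"got {s_condition_shape}"
        )
    # The block preserves shapes: outputs match the non-condition inputs
    return mv_shape, s_shape


out_mv, out_s = expected_output_shapes(
    mv_shape=(2, 10, 8, 16), mv_condition_shape=(2, 5, 4, 16),
    mv_channels=8, condition_mv_channels=4, s_channels=3, condition_s_channels=2,
    s_shape=(2, 10, 3), s_condition_shape=(2, 5, 2),
)
print(out_mv, out_s)  # (2, 10, 8, 16) (2, 10, 3)
```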