lgatr.layers.lgatr_block.LGATrBlock

class lgatr.layers.lgatr_block.LGATrBlock(mv_channels, s_channels, attention, mlp, dropout_prob=None)[source]

Bases: Module

L-GATr encoder block.

Inputs are first processed by a sub-block consisting of LayerNorm, multi-head geometric self-attention, and a residual connection. The result then passes through a second sub-block consisting of another LayerNorm, an item-wise two-layer geometric MLP with GELU activations, and another residual connection.

Parameters:
  • mv_channels (int) – Number of input and output multivector channels

  • s_channels (int) – Number of input and output scalar channels

  • attention (SelfAttentionConfig) – Attention configuration

  • mlp (MLPConfig) – MLP configuration

  • dropout_prob (float or None) – Dropout probability
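The pre-norm residual structure described above can be sketched in plain PyTorch. This is an illustrative analogue only, using standard `nn.LayerNorm`, `nn.MultiheadAttention`, and `nn.Linear` in place of the geometric (equivariant) layers that the real block uses; the class name and channel sizes here are hypothetical.

```python
import torch
from torch import nn


class PreNormBlockSketch(nn.Module):
    """Hypothetical plain-PyTorch analogue of the two-sub-block structure:
    LayerNorm -> self-attention -> residual, then
    LayerNorm -> two-layer MLP with GELU -> residual."""

    def __init__(self, channels, num_heads=4, dropout_prob=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(channels)
        self.mlp = nn.Sequential(
            nn.Linear(channels, 2 * channels),
            nn.GELU(),
            nn.Linear(2 * channels, channels),
        )
        # dropout_prob=None disables dropout, mirroring the constructor above
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob else nn.Identity()

    def forward(self, x):
        # First sub-block: LayerNorm, self-attention, residual connection
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + self.dropout(h)
        # Second sub-block: LayerNorm, item-wise MLP, residual connection
        h = self.mlp(self.norm2(x))
        x = x + self.dropout(h)
        return x


block = PreNormBlockSketch(channels=16)
out = block(torch.randn(2, 5, 16))  # input and output shapes match
```

Because both sub-blocks are residual, the input and output channel counts must agree, which is why the constructor takes a single `mv_channels` and a single `s_channels` for both input and output.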

forward(multivectors, scalars, additional_qk_features_mv=None, additional_qk_features_s=None, **attn_kwargs)[source]

Forward pass of the transformer encoder block.

Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, items, mv_channels, 16).

  • scalars (torch.Tensor) – Input scalars with shape (…, items, s_channels).

  • additional_qk_features_mv (None or torch.Tensor) – Additional multivector Q/K features with shape (…, items, add_qk_mv_channels, 16).

  • additional_qk_features_s (None or torch.Tensor) – Additional scalar Q/K features with shape (…, items, add_qk_s_channels).

  • **attn_kwargs – Optional keyword arguments passed to attention.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, mv_channels, 16).

  • outputs_s (torch.Tensor) – Output scalars with shape (…, items, s_channels).
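A minimal sketch of the tensor layout the forward pass expects, assuming a single leading batch dimension for the "…" axes (an assumption; any number of leading dimensions should work). The trailing 16 is the number of components of a multivector in the four-dimensional spacetime algebra (2**4 = 16); the concrete channel counts below are arbitrary examples.

```python
import torch

# Hypothetical sizes: batch of 2 events, 8 items (e.g. particles) each
batch, items, mv_channels, s_channels = 2, 8, 4, 3

# Multivector inputs: one 16-component multivector per item and channel
multivectors = torch.randn(batch, items, mv_channels, 16)

# Scalar inputs: plain per-item features, no multivector axis
scalars = torch.randn(batch, items, s_channels)

# The block returns tensors of the same shapes:
#   outputs_mv: (batch, items, mv_channels, 16)
#   outputs_s:  (batch, items, s_channels)
```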