lgatr.nets.lgatr.LGATr

class lgatr.nets.lgatr.LGATr(num_blocks, in_mv_channels, out_mv_channels, hidden_mv_channels, in_s_channels, out_s_channels, hidden_s_channels, attention, mlp, reinsert_mv_channels=None, reinsert_s_channels=None, dropout_prob=None, checkpoint_blocks=False)

Bases: Module

L-GATr network.

It combines num_blocks L-GATr transformer blocks, each consisting of geometric self-attention layers, a geometric MLP, residual connections, and normalization layers. In addition, there are initial and final equivariant linear layers.

Expects input multivectors of shape (…, items, in_mv_channels, 16) and returns output multivectors of shape (…, items, out_mv_channels, 16). Hidden representations have shape (…, items, hidden_mv_channels, 16). The optional scalar channels follow the same pattern, without the trailing multivector dimension of 16.

Parameters:
  • num_blocks (int) – Number of transformer blocks.

  • in_mv_channels (int) – Number of input multivector channels.

  • out_mv_channels (int) – Number of output multivector channels.

  • hidden_mv_channels (int) – Number of hidden multivector channels.

  • in_s_channels (None or int) – If not None, sets the number of scalar input channels.

  • out_s_channels (None or int) – If not None, sets the number of scalar output channels.

  • hidden_s_channels (None or int) – If not None, sets the number of scalar hidden channels.

  • attention (Dict) – Dictionary of settings for SelfAttentionConfig.

  • mlp (Dict) – Dictionary of settings for MLPConfig.

  • reinsert_mv_channels (None or Tuple[int]) – If not None, specifies multivector channels that will be reinserted in every attention layer.

  • reinsert_s_channels (None or Tuple[int]) – If not None, specifies scalar channels that will be reinserted in every attention layer.

  • dropout_prob (float or None) – Dropout probability.

  • checkpoint_blocks (bool) – Whether to use checkpointing for the blocks. If True, will save memory at the cost of speed.
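
The parameter list above can be collected into one set of keyword arguments. The following is a minimal sketch, not a reference configuration: the helper names `example_lgatr_kwargs` and `build_example_model` are hypothetical, all channel counts are arbitrary, and the contents of the `attention` and `mlp` dictionaries are assumptions (the fields accepted by SelfAttentionConfig and MLPConfig are not listed on this page).

```python
def example_lgatr_kwargs():
    """Keyword arguments for a small LGATr, using the parameter names documented above."""
    return dict(
        num_blocks=2,
        in_mv_channels=1,
        out_mv_channels=1,
        hidden_mv_channels=8,
        in_s_channels=2,
        out_s_channels=1,
        hidden_s_channels=16,
        # ASSUMPTION: `num_heads` is a field of SelfAttentionConfig.
        attention={"num_heads": 2},
        # ASSUMPTION: an empty dict leaves MLPConfig at its defaults.
        mlp={},
        reinsert_mv_channels=None,
        reinsert_s_channels=None,
        dropout_prob=None,
        checkpoint_blocks=False,
    )


def build_example_model():
    # Imported lazily so the kwargs above can be inspected
    # without the library being installed.
    from lgatr.nets.lgatr import LGATr

    return LGATr(**example_lgatr_kwargs())
```

Keeping the arguments in a plain dictionary makes it easy to log or override individual settings, for example setting `checkpoint_blocks=True` when memory is the bottleneck.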

forward(multivectors, scalars=None, **attn_kwargs)

Forward pass of the network.

Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, items, in_mv_channels, 16).

  • scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, in_s_channels).

  • **attn_kwargs – Optional keyword arguments passed to attention.

Return type:

Tuple[Tensor, Optional[Tensor]]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, out_mv_channels, 16).

  • outputs_s (None or torch.Tensor) – Output scalars with shape (…, items, out_s_channels). None if out_s_channels is None.
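
The shape contract of forward can be written down as a small pure-Python check. This is a sketch of the documented shapes only: `expected_output_shapes` is a hypothetical helper, not part of lgatr, and it does not run the network.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def expected_output_shapes(
    mv_shape: Shape,
    in_mv_channels: int,
    out_mv_channels: int,
    s_shape: Optional[Shape] = None,
    in_s_channels: Optional[int] = None,
    out_s_channels: Optional[int] = None,
) -> Tuple[Shape, Optional[Shape]]:
    """Return the documented (outputs_mv, outputs_s) shapes for given input shapes."""
    mv_shape = tuple(mv_shape)
    # Multivectors must look like (..., items, in_mv_channels, 16).
    if len(mv_shape) < 3 or mv_shape[-2:] != (in_mv_channels, 16):
        raise ValueError(f"expected (..., items, {in_mv_channels}, 16), got {mv_shape}")
    lead = mv_shape[:-2]  # leading dimensions (..., items)

    # Scalars, if given, must share the leading dimensions: (..., items, in_s_channels).
    if s_shape is not None and tuple(s_shape) != lead + (in_s_channels,):
        raise ValueError(f"expected {lead + (in_s_channels,)}, got {tuple(s_shape)}")

    out_mv = lead + (out_mv_channels, 16)
    # No scalar output when out_s_channels is None.
    out_s = None if out_s_channels is None else lead + (out_s_channels,)
    return out_mv, out_s
```

For a batch of 32 events with 10 items each, `expected_output_shapes((32, 10, 1, 16), 1, 3, s_shape=(32, 10, 2), in_s_channels=2, out_s_channels=4)` returns `((32, 10, 3, 16), (32, 10, 4))`. With the scalar arguments left at None, the second entry is None.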