lgatr.nets.lgatr.LGATr
- class lgatr.nets.lgatr.LGATr(num_blocks, in_mv_channels, out_mv_channels, hidden_mv_channels, in_s_channels, out_s_channels, hidden_s_channels, attention, mlp, reinsert_mv_channels=None, reinsert_s_channels=None, dropout_prob=None, checkpoint_blocks=False)
Bases: torch.nn.Module
L-GATr network.
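It stacks num_blocks L-GATr transformer blocks, each consisting of geometric self-attention layers, a geometric MLP, residual connections, and normalization layers. In addition, an initial equivariant linear layer maps the inputs to the hidden channels before the first block, and a final equivariant linear layer maps the hidden channels to the outputs after the last block.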
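The input is expected to have shape (…, items, in_mv_channels, 16) and the output has shape (…, items, out_mv_channels, 16). Hidden representations have shape (…, items, hidden_mv_channels, 16). The optional scalar channels follow the same pattern without the final axis of size 16, e.g. (…, items, hidden_s_channels) for the hidden scalars.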
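The shape contract above can be written out as a small pure-Python helper. The function name `lgatr_shapes` is hypothetical and not part of lgatr; it only manipulates shape tuples, so it runs without PyTorch. The last axis has 16 entries because the spacetime geometric algebra has 2^4 = 16 basis elements. Scalar shapes are returned as None when the corresponding channel count is None, mirroring the parameter descriptions.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]

MV_COMPONENTS = 16  # 2**4 basis elements of the spacetime geometric algebra


def lgatr_shapes(
    batch_shape: Shape,
    items: int,
    in_mv_channels: int,
    hidden_mv_channels: int,
    out_mv_channels: int,
    in_s_channels: Optional[int] = None,
    hidden_s_channels: Optional[int] = None,
    out_s_channels: Optional[int] = None,
) -> Dict[str, Optional[Shape]]:
    """Hypothetical helper: return the documented tensor shapes.

    Scalar entries are None when the corresponding channel count is None.
    """
    prefix = tuple(batch_shape) + (items,)  # (..., items)

    def mv(channels: int) -> Shape:
        # Multivector tensors: (..., items, channels, 16)
        return prefix + (channels, MV_COMPONENTS)

    def s(channels: Optional[int]) -> Optional[Shape]:
        # Scalar tensors: (..., items, channels), or absent if channels is None
        return None if channels is None else prefix + (channels,)

    return {
        "in_mv": mv(in_mv_channels),
        "hidden_mv": mv(hidden_mv_channels),
        "out_mv": mv(out_mv_channels),
        "in_s": s(in_s_channels),
        "hidden_s": s(hidden_s_channels),
        "out_s": s(out_s_channels),
    }


shapes = lgatr_shapes((32,), 20, in_mv_channels=1, hidden_mv_channels=8,
                      out_mv_channels=1, in_s_channels=None,
                      hidden_s_channels=16, out_s_channels=None)
print(shapes["hidden_mv"])  # (32, 20, 8, 16)
print(shapes["out_s"])      # None
```

Only the channel axis changes between input, hidden, and output; the leading (…, items) axes are carried through unchanged.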
- Parameters:
num_blocks (int) – Number of transformer blocks.
in_mv_channels (int) – Number of input multivector channels.
out_mv_channels (int) – Number of output multivector channels.
hidden_mv_channels (int) – Number of hidden multivector channels.
in_s_channels (None or int) – If not None, sets the number of scalar input channels.
out_s_channels (None or int) – If not None, sets the number of scalar output channels.
hidden_s_channels (None or int) – If not None, sets the number of scalar hidden channels.
attention (Dict) – Configuration data for SelfAttentionConfig.
mlp (Dict) – Configuration data for MLPConfig.
reinsert_mv_channels (None or Tuple[int]) – If not None, specifies multivector channels that will be reinserted in every attention layer.
reinsert_s_channels (None or Tuple[int]) – If not None, specifies scalar channels that will be reinserted in every attention layer.
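dropout_prob (float or None) – Dropout probability.
checkpoint_blocks (bool) – Whether to use gradient checkpointing for the blocks. If True, this saves memory at the cost of speed, because intermediate activations are recomputed during the backward pass instead of being stored.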
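The attention and mlp arguments are plain dicts that end up as SelfAttentionConfig and MLPConfig objects. The exact conversion and the real field names are not documented in this section; the sketch below assumes a dataclass with a `cast` classmethod that accepts either a dict or an existing config object. `AttentionConfigSketch`, `num_heads`, and `multi_query` are hypothetical names used only for illustration; consult SelfAttentionConfig and MLPConfig for the actual fields.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Union


@dataclass
class AttentionConfigSketch:
    """Hypothetical stand-in for SelfAttentionConfig; field names are assumptions."""

    num_heads: int = 8
    multi_query: bool = False

    @classmethod
    def cast(
        cls, config: Union["AttentionConfigSketch", Dict[str, Any]]
    ) -> "AttentionConfigSketch":
        # Pass an existing config object through unchanged.
        if isinstance(config, cls):
            return config
        # Build a config from a dict, rejecting keys that are not fields.
        if isinstance(config, dict):
            known = {f.name for f in fields(cls)}
            unknown = set(config) - known
            if unknown:
                raise ValueError(f"Unknown config keys: {sorted(unknown)}")
            return cls(**config)
        raise TypeError(f"Cannot cast {type(config).__name__} to {cls.__name__}")


cfg = AttentionConfigSketch.cast({"num_heads": 4})
print(cfg)  # AttentionConfigSketch(num_heads=4, multi_query=False)
```

Accepting a dict keeps the constructor call short and serializable (e.g. from a YAML or JSON config), while unspecified fields fall back to their defaults.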
- forward(multivectors, scalars=None, **attn_kwargs)
Forward pass of the network.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, items, in_mv_channels, 16).
scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, in_s_channels).
**attn_kwargs – Optional keyword arguments passed to attention.
- Return type:
Tuple[Tensor, Optional[Tensor]]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, out_mv_channels, 16).
outputs_s (None or torch.Tensor) – Output scalars with shape (…, items, out_s_channels). None if out_s_channels is None.
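Because forward returns a pair whose second element may be None, callers typically unpack it and branch on whether the scalar output is present. The sketch below restates the documented input and output shapes of forward as a shape-only check that returns the same Tuple[shape, Optional[shape]] structure. `check_forward_shapes` is a hypothetical helper, not part of lgatr, and it additionally assumes that scalars must be passed exactly when in_s_channels is set, which the parameter list does not state explicitly.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_forward_shapes(
    mv_shape: Shape,
    s_shape: Optional[Shape],
    in_mv_channels: int,
    out_mv_channels: int,
    in_s_channels: Optional[int] = None,
    out_s_channels: Optional[int] = None,
) -> Tuple[Shape, Optional[Shape]]:
    """Hypothetical check: validate inputs, return (outputs_mv shape, outputs_s shape or None)."""
    mv_shape = tuple(mv_shape)
    s_shape = None if s_shape is None else tuple(s_shape)

    # Multivectors: (..., items, in_mv_channels, 16)
    if len(mv_shape) < 3 or mv_shape[-2:] != (in_mv_channels, 16):
        raise ValueError(f"expected (..., items, {in_mv_channels}, 16), got {mv_shape}")
    leading = mv_shape[:-2]  # (..., items)

    # Assumption: scalars are required exactly when in_s_channels is set.
    if in_s_channels is None:
        if s_shape is not None:
            raise ValueError("scalars given but in_s_channels is None")
    elif s_shape != leading + (in_s_channels,):
        raise ValueError(f"expected scalars {leading + (in_s_channels,)}, got {s_shape}")

    out_mv = leading + (out_mv_channels, 16)
    out_s = None if out_s_channels is None else leading + (out_s_channels,)
    return out_mv, out_s


outputs_mv_shape, outputs_s_shape = check_forward_shapes(
    (32, 20, 1, 16), None, in_mv_channels=1, out_mv_channels=1
)
print(outputs_mv_shape)  # (32, 20, 1, 16)
print(outputs_s_shape)   # None
```

With a real model, the same unpacking pattern applies to the returned tensors: check whether outputs_s is None before using it.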