lgatr.nets.conditional_lgatr.ConditionalLGATr

class lgatr.nets.conditional_lgatr.ConditionalLGATr(num_blocks, in_mv_channels, condition_mv_channels, out_mv_channels, hidden_mv_channels, in_s_channels, condition_s_channels, out_s_channels, hidden_s_channels, attention, crossattention, mlp, dropout_prob=None, checkpoint_blocks=False)[source]

Bases: Module

Conditional L-GATr network. Assumes that the condition is already preprocessed, e.g. with a non-conditional LGATr network.

It combines num_blocks conditional L-GATr transformer blocks, each consisting of geometric self-attention layers, geometric cross-attention layers, a geometric MLP, residual connections, and normalization layers. In addition, there are initial and final equivariant linear layers. A construction sketch is given below the parameter list.

Parameters:
  • num_blocks (int) – Number of transformer blocks.

  • in_mv_channels (int) – Number of input multivector channels.

  • condition_mv_channels (int) – Number of condition multivector channels.

  • out_mv_channels (int) – Number of output multivector channels.

  • hidden_mv_channels (int) – Number of hidden multivector channels.

  • in_s_channels (None or int) – If not None, sets the number of scalar input channels.

  • condition_s_channels (None or int) – If not None, sets the number of scalar condition channels.

  • out_s_channels (None or int) – If not None, sets the number of scalar output channels.

  • hidden_s_channels (None or int) – If not None, sets the number of scalar hidden channels.

  • attention (Dict) – Data for SelfAttentionConfig.

  • crossattention (Dict) – Data for CrossAttentionConfig.

  • mlp (Dict) – Data for MLPConfig.

  • dropout_prob (float or None) – Dropout probability.

  • checkpoint_blocks (bool) – Whether to use checkpointing for the transformer blocks to save memory.
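
A minimal construction sketch (illustrative, not taken from the library documentation): the channel counts and the number of blocks below are arbitrary, and the empty attention, crossattention, and mlp dicts are assumed to fall back to the default SelfAttentionConfig, CrossAttentionConfig, and MLPConfig settings.

  from lgatr.nets.conditional_lgatr import ConditionalLGATr

  # Illustrative channel counts; adjust to your data.
  net = ConditionalLGATr(
      num_blocks=4,
      in_mv_channels=1,
      condition_mv_channels=1,
      out_mv_channels=1,
      hidden_mv_channels=8,
      in_s_channels=2,
      condition_s_channels=2,
      out_s_channels=2,
      hidden_s_channels=16,
      attention={},        # assumed to select the default SelfAttentionConfig
      crossattention={},   # assumed to select the default CrossAttentionConfig
      mlp={},              # assumed to select the default MLPConfig
  )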

forward(multivectors, multivectors_condition, scalars=None, scalars_condition=None, attn_kwargs={}, crossattn_kwargs={})[source]

Forward pass of the network.

Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, items, in_mv_channels, 16).

  • multivectors_condition (torch.Tensor) – Condition multivectors with shape (…, items, condition_mv_channels, 16).

  • scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, in_s_channels).

  • scalars_condition (None or torch.Tensor) – Optional condition scalars with shape (…, items, condition_s_channels).

  • attn_kwargs (None or torch.Tensor or AttentionBias) – Optional attention arguments for the self-attention layers.

  • crossattn_kwargs (None or torch.Tensor or AttentionBias) – Optional attention arguments for the cross-attention with the condition.

Return type:

Tuple[Tensor, Optional[Tensor]]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, out_mv_channels, 16).

  • outputs_s (None or torch.Tensor) – Output scalars with shape (…, items, out_s_channels). None if out_s_channels=None.
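
A usage sketch for the forward pass, continuing the construction example above. The batch size and number of items are arbitrary, the condition uses the same number of items as the input for simplicity, and the expected output shapes follow the parameter documentation.

  import torch

  batch, items = 32, 10

  # Shapes follow the documentation above; channel counts match the
  # construction sketch (1 multivector channel, 2 scalar channels).
  mv = torch.randn(batch, items, 1, 16)       # (..., items, in_mv_channels, 16)
  mv_cond = torch.randn(batch, items, 1, 16)  # (..., items, condition_mv_channels, 16)
  s = torch.randn(batch, items, 2)            # (..., items, in_s_channels)
  s_cond = torch.randn(batch, items, 2)       # (..., items, condition_s_channels)

  out_mv, out_s = net(mv, mv_cond, scalars=s, scalars_condition=s_cond)
  print(out_mv.shape)  # torch.Size([32, 10, 1, 16])
  print(out_s.shape)   # torch.Size([32, 10, 2])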