lgatr.nets.conditional_lgatr.ConditionalLGATr
- class lgatr.nets.conditional_lgatr.ConditionalLGATr(num_blocks, in_mv_channels, condition_mv_channels, out_mv_channels, hidden_mv_channels, in_s_channels, condition_s_channels, out_s_channels, hidden_s_channels, attention, crossattention, mlp, dropout_prob=None, checkpoint_blocks=False)[source]
Bases: Module
Conditional L-GATr network. Assumes that the condition is already preprocessed, e.g. with a non-conditional L-GATr network.
It combines num_blocks conditional L-GATr transformer blocks, each consisting of a geometric self-attention layer, a geometric cross-attention layer, a geometric MLP, residual connections, and normalization layers. In addition, there are initial and final equivariant linear layers. A construction sketch follows the parameter list below.
- Parameters:
num_blocks (int) – Number of transformer blocks.
in_mv_channels (int) – Number of input multivector channels.
condition_mv_channels (int) – Number of condition multivector channels.
out_mv_channels (int) – Number of output multivector channels.
hidden_mv_channels (int) – Number of hidden multivector channels.
in_s_channels (None or int) – If not None, sets the number of scalar input channels.
condition_s_channels (None or int) – If not None, sets the number of scalar condition channels.
out_s_channels (None or int) – If not None, sets the number of scalar output channels.
hidden_s_channels (None or int) – If not None, sets the number of scalar hidden channels.
attention (Dict) – Data for SelfAttentionConfig.
crossattention (Dict) – Data for CrossAttentionConfig.
mlp (Dict) – Data for MLPConfig.
dropout_prob (float or None) – Dropout probability.
checkpoint_blocks (bool) – Whether to use checkpointing for the transformer blocks to save memory.
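A minimal construction sketch based on the signature above. The channel counts are illustrative, and passing empty dicts for attention, crossattention, and mlp assumes that SelfAttentionConfig, CrossAttentionConfig, and MLPConfig provide usable defaults, which this page does not confirm.

```python
from lgatr.nets.conditional_lgatr import ConditionalLGATr

# Sketch only: channel sizes are arbitrary; empty config dicts assume the
# attention / cross-attention / MLP configs fall back to default settings.
net = ConditionalLGATr(
    num_blocks=4,
    in_mv_channels=8,
    condition_mv_channels=8,
    out_mv_channels=1,
    hidden_mv_channels=16,
    in_s_channels=4,
    condition_s_channels=4,
    out_s_channels=1,
    hidden_s_channels=32,
    attention={},
    crossattention={},
    mlp={},
    dropout_prob=None,
    checkpoint_blocks=False,
)
```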
- forward(multivectors, multivectors_condition, scalars=None, scalars_condition=None, attn_kwargs={}, crossattn_kwargs={})[source]
Forward pass of the network.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, items, in_mv_channels, 16).
multivectors_condition (torch.Tensor) – Input condition multivectors with shape (…, items, condition_mv_channels, 16).
scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, in_s_channels).
scalars_condition (None or torch.Tensor) – Optional condition scalars with shape (…, items, condition_s_channels).
attn_kwargs (None or torch.Tensor or AttentionBias) – Optional arguments for the self-attention layers, e.g. an attention mask or bias.
crossattn_kwargs (None or torch.Tensor or AttentionBias) – Optional arguments for the cross-attention layers over the condition.
- Return type:
Tuple[Tensor, Optional[Tensor]]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, out_mv_channels, 16).
outputs_s (None or torch.Tensor) – Output scalars with shape (…, items, out_s_channels). None if out_s_channels=None.
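A hedged usage sketch of the forward pass, following the documented shapes. It reuses the net object from the construction sketch above; the batch size, item counts, and channel numbers are illustrative assumptions.

```python
import torch

batch, items, items_cond = 2, 128, 64

# Multivector inputs have shape (..., items, channels, 16),
# scalar inputs have shape (..., items, channels).
mv = torch.randn(batch, items, 8, 16)            # in_mv_channels = 8
mv_cond = torch.randn(batch, items_cond, 8, 16)  # condition_mv_channels = 8
s = torch.randn(batch, items, 4)                 # in_s_channels = 4
s_cond = torch.randn(batch, items_cond, 4)       # condition_s_channels = 4

out_mv, out_s = net(
    multivectors=mv,
    multivectors_condition=mv_cond,
    scalars=s,
    scalars_condition=s_cond,
)
# out_mv: (batch, items, out_mv_channels, 16)
# out_s:  (batch, items, out_s_channels), or None if out_s_channels=None
```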