lgatr.layers.attention.self_attention.SelfAttention
- class lgatr.layers.attention.self_attention.SelfAttention(config)[source]
Bases: Module
L-GATr self-attention.
Constructs queries, keys, and values, computes attention, and projects linearly to outputs.
- Parameters:
config (SelfAttentionConfig) – Attention configuration.
- forward(multivectors, additional_qk_features_mv=None, scalars=None, additional_qk_features_s=None, **attn_kwargs)[source]
Computes self-attention.
The result is the following:
```
# For each head
queries = linear_channels(inputs)
keys = linear_channels(inputs)
values = linear_channels(inputs)
hidden = attention_items(queries, keys, values, biases=biases)
head_output = linear_channels(hidden)

# Combine results
output = concatenate_heads(head_output)
```
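This is the standard multi-head attention data flow applied channel-wise. As a point of reference only, a minimal plain-PyTorch sketch of that pattern is shown below; it deliberately ignores the geometric-algebra structure (ordinary `nn.Linear` layers and dot-product attention instead of L-GATr's equivariant linear maps and invariant inner product), so it illustrates the data flow, not the L-GATr implementation.

```python
import torch
from torch import nn


class ToyMultiHeadSelfAttention(nn.Module):
    """Illustration of the queries/keys/values -> attention -> output pattern.

    NOT the L-GATr implementation: L-GATr uses equivariant linear maps on
    multivector channels and an invariant inner product for the attention
    weights; this toy version uses plain linear layers and dot products.
    """

    def __init__(self, channels, num_heads):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.qkv = nn.Linear(channels, 3 * channels)  # queries, keys, values
        self.out = nn.Linear(channels, channels)      # final linear projection

    def forward(self, inputs):
        # inputs: (..., items, channels)
        *batch, items, channels = inputs.shape
        qkv = self.qkv(inputs).reshape(*batch, items, 3, self.num_heads, self.head_dim)
        queries, keys, values = qkv.unbind(dim=-3)  # each (..., items, heads, head_dim)
        # Move the head dimension in front of the item dimension for attention
        queries = queries.transpose(-3, -2)
        keys = keys.transpose(-3, -2)
        values = values.transpose(-3, -2)
        hidden = nn.functional.scaled_dot_product_attention(queries, keys, values)
        # Concatenate heads and project to outputs
        head_output = hidden.transpose(-3, -2).reshape(*batch, items, channels)
        return self.out(head_output)
```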
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, items, mv_channels, 16).
additional_qk_features_mv (None or torch.Tensor) – Additional multivector Q/K features with shape (…, items, add_qk_mv_channels, 16).
scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, s_channels).
additional_qk_features_s (None or torch.Tensor) – Additional scalar Q/K features with shape (…, items, add_qk_s_channels).
**attn_kwargs – Optional keyword arguments passed to attention.
- Return type:
Tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, mv_channels, 16).
output_scalars (torch.Tensor) – Output scalars with shape (…, items, s_channels).
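A hedged usage sketch follows. The SelfAttentionConfig import path and field names used here (in_mv_channels, out_mv_channels, in_s_channels, out_s_channels, num_heads) are assumptions for illustration and should be checked against the installed package; the tensor shapes follow the parameter documentation above.

```python
import torch
from lgatr.layers.attention.self_attention import SelfAttention
# Import path and config fields below are assumptions; check the installed package.
from lgatr.layers.attention.config import SelfAttentionConfig

config = SelfAttentionConfig(
    in_mv_channels=8,
    out_mv_channels=8,
    in_s_channels=4,
    out_s_channels=4,
    num_heads=4,
)
attention = SelfAttention(config)

batch, items = 2, 128
multivectors = torch.randn(batch, items, 8, 16)  # (..., items, mv_channels, 16)
scalars = torch.randn(batch, items, 4)           # (..., items, s_channels)

outputs_mv, output_scalars = attention(multivectors, scalars=scalars)
print(outputs_mv.shape)      # (batch, items, out_mv_channels, 16)
print(output_scalars.shape)  # (batch, items, out_s_channels)
```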