lgatr.layers.attention.config.SelfAttentionConfig
- class lgatr.layers.attention.config.SelfAttentionConfig(in_mv_channels=None, out_mv_channels=None, in_s_channels=None, out_s_channels=None, additional_qk_mv_channels=0, additional_qk_s_channels=0, output_init='default', dropout_prob=None, num_heads=8, multi_query=False, increase_hidden_channels=2, head_scale=False)
Bases: object
Configuration for self-attention.
- Parameters:
num_heads (int) – Number of attention heads.
multi_query (bool) – Whether to use multi-query attention (default: False). Multi-query attention reduces memory consumption and parameter count by sharing a single set of keys and values across all heads.
increase_hidden_channels (int) – Factor by which to increase the number of hidden channels (both multivectors and scalars). Vanilla transformers use 1; we use 2 for backward compatibility.
head_scale (bool) – Whether to use HeadScaleMHA, following NormFormer (see https://arxiv.org/pdf/2110.09456). Each head is scaled by a learnable parameter before the heads are combined.
- Parameters auto-set by LGATr:
in_mv_channels (int) – Number of input multivector channels.
out_mv_channels (int) – Number of output multivector channels.
in_s_channels (int) – Input scalar channels. If None, no scalars are expected nor returned.
out_s_channels (int) – Output scalar channels. If None, no scalars are expected nor returned.
additional_qk_mv_channels (int) – Number of additional multivector channels provided for the keys and queries; 0 means none are provided.
additional_qk_s_channels (int) – Number of additional scalar channels provided for the keys and queries; 0 means none are provided.
output_init (str) – Initialization scheme for the final linear layer.
dropout_prob (float or None) – Dropout probability.
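The split between user-set and auto-set parameters can be illustrated with a minimal sketch. `SelfAttentionConfigSketch` below is a hypothetical stand-in that only mirrors the field names and defaults from the signature above; it is not the library class, and the channel values filled in at the end are made-up numbers standing in for whatever LGATr would set.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class SelfAttentionConfigSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    in_mv_channels: Optional[int] = None
    out_mv_channels: Optional[int] = None
    in_s_channels: Optional[int] = None
    out_s_channels: Optional[int] = None
    additional_qk_mv_channels: int = 0
    additional_qk_s_channels: int = 0
    output_init: str = "default"
    dropout_prob: Optional[float] = None
    num_heads: int = 8
    multi_query: bool = False
    increase_hidden_channels: int = 2
    head_scale: bool = False


# The user only chooses the attention options; channel counts stay None.
user_cfg = SelfAttentionConfigSketch(num_heads=4, multi_query=True)

# Assumption: the enclosing model later fills in the channel counts.
# The numbers here are arbitrary illustration values.
full_cfg = replace(
    user_cfg,
    in_mv_channels=16,
    out_mv_channels=16,
    in_s_channels=32,
    out_s_channels=32,
)
```

Using `dataclasses.replace` keeps the user-facing object unchanged and returns a completed copy, which is one possible way to separate the two groups of parameters.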
- additional_qk_mv_channels: int = 0
- additional_qk_s_channels: int = 0
- dropout_prob: Optional[float] = None
- head_scale: bool = False
Returns the number of hidden multivector channels.
Returns the number of hidden scalar channels.
- in_mv_channels: Optional[int] = None
- in_s_channels: Optional[int] = None
- multi_query: bool = False
- num_heads: int = 8
- out_mv_channels: Optional[int] = None
- out_s_channels: Optional[int] = None
- output_init: str = 'default'
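To make the effect of the `multi_query` option concrete, the following sketch counts projected channels for queries, keys, and values in the generic multi-query scheme, where all heads share one key/value set. The helper `qkv_channels` and its `hidden_per_head` argument are hypothetical names for illustration; how the library actually sizes its projections is not stated here.

```python
def qkv_channels(hidden_per_head: int, num_heads: int, multi_query: bool) -> dict:
    """Count projected channels for queries, keys, and values (illustrative)."""
    # Multi-query attention keeps one key/value set shared by all heads;
    # standard multi-head attention keeps one set per head.
    kv_heads = 1 if multi_query else num_heads
    return {
        "q": hidden_per_head * num_heads,
        "k": hidden_per_head * kv_heads,
        "v": hidden_per_head * kv_heads,
    }


multi_head_counts = qkv_channels(hidden_per_head=4, num_heads=8, multi_query=False)
multi_query_counts = qkv_channels(hidden_per_head=4, num_heads=8, multi_query=True)
```

In this sketch the query width is the same in both cases, while the key and value widths shrink by a factor of `num_heads` under multi-query attention, which is where the memory and parameter savings come from.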