lgatr.layers.attention.config.CrossAttentionConfig
- class lgatr.layers.attention.config.CrossAttentionConfig(in_q_mv_channels=None, in_kv_mv_channels=None, out_mv_channels=None, out_s_channels=None, in_q_s_channels=None, in_kv_s_channels=None, additional_q_mv_channels=0, additional_q_s_channels=0, additional_k_mv_channels=0, additional_k_s_channels=0, output_init='default', dropout_prob=None, num_heads=8, multi_query=False, increase_hidden_channels=2, head_scale=False)[source]
Bases: object
Configuration for cross-attention. A usage sketch follows the parameter list below.
- Parameters:
num_heads (int) – Number of attention heads.
multi_query (bool) – Whether to use multi-query attention; default is False. Multi-query attention reduces memory consumption and parameter count by sharing a single set of keys and values across all heads (see the sketch at the end of this page).
increase_hidden_channels (int) – Factor by which to increase the number of hidden channels (both multivectors and scalars). Vanilla transformers use 1; we use 2 for backward compatibility.
head_scale (bool) – Whether to use HeadScaleMHA following the NormFormer, see https://arxiv.org/pdf/2110.09456. Before the heads are combined, each head is scaled by a learnable parameter (also illustrated in the sketch at the end of this page).
- Parameters auto-set by LGATr:
in_q_mv_channels (int) – Number of input query multivector channels.
in_kv_mv_channels (int) – Number of input key/value multivector channels.
out_mv_channels (int) – Number of output multivector channels.
in_q_s_channels (int) – Input query scalar channels. If None, no scalars are expected nor returned.
in_kv_s_channels (int) – Input key/value scalar channels. If None, no scalars are expected nor returned.
out_s_channels (int) – Output scalar channels. If None, no scalars are expected nor returned.
additional_q_mv_channels (int) – Number of additional multivector feature channels that will be provided for the queries.
additional_q_s_channels (int) – Number of additional scalar feature channels that will be provided for the queries.
additional_k_mv_channels (int) – Number of additional multivector feature channels that will be provided for the keys.
additional_k_s_channels (int) – Number of additional scalar feature channels that will be provided for the keys.
output_init (str) – Initialization scheme for the final linear layer.
dropout_prob (float or None) – Dropout probability.
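A minimal usage sketch, assuming the config is constructed by hand; in practice the channel counts are auto-set by LGATr, so every value below is an illustrative assumption rather than a library default:

from lgatr.layers.attention.config import CrossAttentionConfig

# Channel counts are normally filled in by LGATr; the values here
# are illustrative assumptions for a standalone example.
config = CrossAttentionConfig(
    in_q_mv_channels=16,         # query multivector channels
    in_kv_mv_channels=16,        # key/value multivector channels
    out_mv_channels=16,          # output multivector channels
    in_q_s_channels=32,          # query scalar channels (None disables scalars)
    in_kv_s_channels=32,         # key/value scalar channels
    out_s_channels=32,           # output scalar channels
    num_heads=8,                 # number of attention heads
    multi_query=True,            # share one key/value set across heads
    increase_hidden_channels=2,  # widen hidden channels by this factor
    head_scale=False,            # no learnable per-head output scaling
)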
- additional_k_mv_channels: int = 0
- additional_k_s_channels: int = 0
- additional_q_mv_channels: int = 0
- additional_q_s_channels: int = 0
- dropout_prob: Optional[float] = None
- head_scale: bool = False
- property hidden_mv_channels
Returns the number of hidden multivector channels.
- property hidden_s_channels
Returns the number of hidden scalar channels.
- in_kv_mv_channels: Optional[int] = None
- in_kv_s_channels: Optional[int] = None
- in_q_mv_channels: Optional[int] = None
- in_q_s_channels: Optional[int] = None
- multi_query: bool = False
- num_heads: int = 8
- out_mv_channels: Optional[int] = None
- out_s_channels: Optional[int] = None
- output_init: str = 'default'