lgatr.layers.attention.config.CrossAttentionConfig

class lgatr.layers.attention.config.CrossAttentionConfig(in_q_mv_channels=None, in_kv_mv_channels=None, out_mv_channels=None, out_s_channels=None, in_q_s_channels=None, in_kv_s_channels=None, additional_q_mv_channels=0, additional_q_s_channels=0, additional_k_mv_channels=0, additional_k_s_channels=0, output_init='default', dropout_prob=None, num_heads=8, multi_query=False, increase_hidden_channels=2, head_scale=False)[source]

Bases: object

Configuration for cross-attention.

Parameters:
  • num_heads (int) – Number of attention heads.

  • multi_query (bool) – Whether to use multi-query attention; default is False. Multi-query attention reduces memory consumption and parameter count by sharing a single set of keys and values across all heads.

  • increase_hidden_channels (int) – Factor by which to increase the number of hidden channels (both multivectors and scalars). Vanilla transformers use 1; we use 2 for backward compatibility.

  • head_scale (bool) – Whether to use HeadScaleMHA, following the NormFormer (https://arxiv.org/pdf/2110.09456): before the heads are combined, each head is scaled by a learnable parameter.

Parameters auto-set by LGATr:
  • in_q_mv_channels (int) – Number of input query multivector channels.

  • in_kv_mv_channels (int) – Number of input key/value multivector channels.

  • out_mv_channels (int) – Number of output multivector channels.

  • in_q_s_channels (int) – Number of input query scalar channels. If None, no scalars are expected or returned.

  • in_kv_s_channels (int) – Number of input key/value scalar channels. If None, no scalars are expected or returned.

  • out_s_channels (int) – Number of output scalar channels. If None, no scalars are expected or returned.

  • additional_q_mv_channels (int) – Number of additional multivector channels provided for the queries.

  • additional_q_s_channels (int) – Number of additional scalar channels provided for the queries.

  • additional_k_mv_channels (int) – Number of additional multivector channels provided for the keys.

  • additional_k_s_channels (int) – Number of additional scalar channels provided for the keys.

  • output_init (str) – Initialization scheme for the final linear layer.

  • dropout_prob (float or None) – Dropout probability.
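
For illustration, a minimal construction sketch. In typical use only the attention-specific options are set by hand and the channel counts are filled in by LGATr, so the channel numbers below are illustrative assumptions for a standalone example, not defaults.

    from lgatr.layers.attention.config import CrossAttentionConfig

    # User-facing attention options; everything else keeps its default.
    config = CrossAttentionConfig(
        num_heads=8,
        multi_query=True,
        increase_hidden_channels=2,
        head_scale=False,
        dropout_prob=0.1,
    )

    # When building a cross-attention layer by hand (outside LGATr), the channel
    # counts that LGATr would normally auto-set have to be supplied as well.
    # These numbers are illustrative assumptions, not library defaults.
    standalone = CrossAttentionConfig(
        in_q_mv_channels=16,
        in_kv_mv_channels=16,
        out_mv_channels=16,
        in_q_s_channels=32,
        in_kv_s_channels=32,
        out_s_channels=32,
        num_heads=8,
    )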

additional_k_mv_channels: int = 0
additional_k_s_channels: int = 0
additional_q_mv_channels: int = 0
additional_q_s_channels: int = 0
classmethod cast(config)[source]

Casts an object to a CrossAttentionConfig.

Return type:

CrossAttentionConfig
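
A usage sketch for cast. It assumes, as is common for such cast helpers, that the method accepts either an existing CrossAttentionConfig (returned unchanged) or a mapping of field values; check the linked source for the exact accepted types.

    from lgatr.layers.attention.config import CrossAttentionConfig

    # Assumed behaviour: a plain dict of settings is converted into a config object.
    raw = {"num_heads": 4, "multi_query": True, "dropout_prob": 0.1}
    config = CrossAttentionConfig.cast(raw)

    # Assumed behaviour: an object that already is a CrossAttentionConfig passes through.
    same = CrossAttentionConfig.cast(config)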

dropout_prob: Optional[float] = None
head_scale: bool = False
property hidden_mv_channels: int | None

Returns the number of hidden multivector channels.

property hidden_s_channels: int | None

Returns the number of hidden scalar channels.
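
A minimal sketch of reading these properties. The exact derivation of the hidden widths (from the input channels, increase_hidden_channels, and num_heads) lives in the source; the channel values below are illustrative assumptions.

    from lgatr.layers.attention.config import CrossAttentionConfig

    cfg = CrossAttentionConfig(
        in_q_mv_channels=16,
        in_kv_mv_channels=16,
        out_mv_channels=16,
        in_q_s_channels=32,
        in_kv_s_channels=32,
        out_s_channels=32,
    )
    print(cfg.hidden_mv_channels)  # derived hidden multivector width (see source for the formula)
    print(cfg.hidden_s_channels)   # derived hidden scalar width
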

in_kv_mv_channels: Optional[int] = None
in_kv_s_channels: Optional[int] = None
in_q_mv_channels: Optional[int] = None
in_q_s_channels: Optional[int] = None
increase_hidden_channels: int = 2
multi_query: bool = False
num_heads: int = 8
out_mv_channels: Optional[int] = None
out_s_channels: Optional[int] = None
output_init: str = 'default'