lgatr.layers.attention.cross_attention.CrossAttention
- class lgatr.layers.attention.cross_attention.CrossAttention(config)
Bases: torch.nn.Module
L-GATr cross-attention.
Constructs queries from the query inputs and keys and values from the key/value inputs, computes attention, and projects the result linearly to the outputs.
- Parameters:
config (SelfAttentionConfig) – Attention configuration.
- forward(multivectors_kv, multivectors_q, scalars_kv=None, scalars_q=None, **attn_kwargs)
Compute cross attention.
The computation is the following:

    # For each head
    queries = linear_channels(inputs_q)
    keys = linear_channels(inputs_kv)
    values = linear_channels(inputs_kv)
    hidden = attention_items(queries, keys, values, **attn_kwargs)
    head_output = linear_channels(hidden)

    # Combine results
    output = concatenate_heads(head_output)
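The per-head computation above can be sketched in plain Python on flat feature vectors. This is a minimal illustration, not the L-GATr implementation: it assumes ordinary scaled dot-product attention with a softmax over the key/value items, treats each item as a plain list of floats instead of multivector and scalar channels, and omits `**attn_kwargs`. The helpers `matmul`, `attention_items`, and `cross_attention` and the per-head weight tuples `(w_q, w_k, w_v, w_o)` are hypothetical names chosen to mirror the pseudocode.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
# Hypothetical per-head weights: (w_q, w_k, w_v, w_o)
HeadWeights = Tuple[Matrix, Matrix, Matrix, Matrix]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_items(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Scaled dot-product attention: one output row per query item."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    # scores[i][j]: similarity of query item i with key/value item j
    scores = [[scale * sum(q * k for q, k in zip(qi, kj)) for kj in keys] for qi in queries]
    weights = [softmax(row) for row in scores]  # each row sums to 1 over items_kv
    return matmul(weights, values)


def cross_attention(x_q: Matrix, x_kv: Matrix, heads: Sequence[HeadWeights]) -> Matrix:
    head_outputs = []
    for w_q, w_k, w_v, w_o in heads:
        queries = matmul(x_q, w_q)  # built from the query inputs
        keys = matmul(x_kv, w_k)  # built from the key/value inputs
        values = matmul(x_kv, w_v)  # built from the key/value inputs
        hidden = attention_items(queries, keys, values)
        head_outputs.append(matmul(hidden, w_o))  # per-head output projection
    # Concatenate the per-head outputs along the feature axis, item by item
    return [[v for h in head_outputs for v in h[i]] for i in range(len(x_q))]
```

Because the attention weights of each query item are normalized over the key/value items, the result has one row per query item no matter how many key/value items there are, which matches the `items_q` axis in the documented output shapes.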
- Parameters:
multivectors_kv (torch.Tensor) – Input multivectors for key and value with shape (…, items_kv, mv_channels, 16).
multivectors_q (torch.Tensor) – Input multivectors for query with shape (…, items_q, mv_channels, 16).
scalars_kv (None or torch.Tensor) – Optional input scalars for key and value with shape (…, items_kv, s_channels).
scalars_q (None or torch.Tensor) – Optional input scalars for query with shape (…, items_q, s_channels).
**attn_kwargs – Optional keyword arguments passed to attention.
- Return type:
Tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, items_q, mv_channels, 16).
output_scalars (torch.Tensor) – Output scalars with shape (…, items_q, s_channels).
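The shape contract documented for forward can be expressed as a small check on shape tuples. The helper `cross_attention_output_shapes` below is hypothetical and does not call L-GATr; it assumes, as the shared names in the parameter list suggest, that query and key/value inputs use the same `mv_channels` and `s_channels`, and that the outputs keep the query's leading dimensions. The documentation does not state the shape of the output scalars when no query scalars are passed, so the sketch returns `None` in that case.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def _check_scalars(name: str, scalars: Shape, multivectors: Shape) -> None:
    # Scalars (..., items, s_channels) must share (..., items) with the multivectors
    if len(scalars) < 2 or scalars[:-1] != multivectors[:-2]:
        raise ValueError(f"{name}: leading dims {scalars[:-1]} != {multivectors[:-2]}")


def cross_attention_output_shapes(
    mv_kv: Shape,
    mv_q: Shape,
    s_kv: Optional[Shape] = None,
    s_q: Optional[Shape] = None,
) -> Tuple[Shape, Optional[Shape]]:
    """Validate documented input shapes; return expected (mv, scalar) output shapes."""
    for name, mv in (("multivectors_kv", mv_kv), ("multivectors_q", mv_q)):
        if len(mv) < 3 or mv[-1] != 16:
            raise ValueError(f"{name} must be (..., items, mv_channels, 16), got {mv}")
    # Assumption: queries and keys/values share mv_channels
    if mv_kv[-2] != mv_q[-2]:
        raise ValueError("mv_channels differ between query and key/value inputs")
    if s_kv is not None:
        _check_scalars("scalars_kv", s_kv, mv_kv)
    if s_q is not None:
        _check_scalars("scalars_q", s_q, mv_q)
    # Assumption: queries and keys/values share s_channels
    if s_kv is not None and s_q is not None and s_kv[-1] != s_q[-1]:
        raise ValueError("s_channels differ between query and key/value inputs")
    # Outputs follow the query items: (..., items_q, mv_channels, 16) and (..., items_q, s_channels)
    return tuple(mv_q), (tuple(s_q) if s_q is not None else None)
```

For example, a batch of 8 with 5 query items attending over 100 key/value items and `mv_channels = 4` gives multivector outputs of shape (8, 5, 4, 16); the number of key/value items does not appear in either output shape.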