lgatr.layers.attention.cross_attention.CrossAttention

class lgatr.layers.attention.cross_attention.CrossAttention(config)[source]

Bases: Module

L-GATr cross-attention.

Constructs queries, keys, and values, computes attention, and projects linearly to outputs.

Parameters:

config (SelfAttentionConfig) – Attention configuration.

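A minimal construction sketch: the CrossAttention import path is taken from the heading above, but the import path for SelfAttentionConfig and the config field names shown below are assumptions, not confirmed by this page.

from lgatr.layers.attention.cross_attention import CrossAttention
# Assumed import path and field names; adjust to the actual config module.
from lgatr.layers.attention.config import SelfAttentionConfig

config = SelfAttentionConfig(
    in_mv_channels=8,   # multivector channels of the inputs (assumed field name)
    out_mv_channels=8,  # multivector channels of the outputs (assumed field name)
    in_s_channels=4,    # scalar channels of the inputs (assumed field name)
    out_s_channels=4,   # scalar channels of the outputs (assumed field name)
    num_heads=4,        # number of attention heads (assumed field name)
)
attn = CrossAttention(config)
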
forward(multivectors_kv, multivectors_q, scalars_kv=None, scalars_q=None, **attn_kwargs)[source]

Compute cross-attention.

Schematically, the forward pass computes the following:

# For each head
queries = linear_channels(inputs_q)
keys = linear_channels(inputs_kv)
values = linear_channels(inputs_kv)
hidden = attention_items(queries, keys, values, **attn_kwargs)
head_output = linear_channels(hidden)

# Combine results
output = concatenate_heads(head_output)

Parameters:
  • multivectors_kv (torch.Tensor) – Input multivectors for key and value with shape (…, items_kv, mv_channels, 16).

  • multivectors_q (torch.Tensor) – Input multivectors for query with shape (…, items_q, mv_channels, 16).

  • scalars_kv (None or torch.Tensor) – Optional input scalars for key and value with shape (…, items_kv, s_channels).

  • scalars_q (None or torch.Tensor) – Optional input scalars for query with shape (…, items_q, s_channels).

  • **attn_kwargs – Optional keyword arguments passed to attention.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, items_q, mv_channels, 16).

  • output_scalars (torch.Tensor) – Output scalars with shape (…, items_q, s_channels).
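
Continuing the construction sketch above, a forward call with the documented shapes could look like this; the batch size and channel counts are illustrative.

import torch

batch, items_q, items_kv = 2, 10, 20
mv_q = torch.randn(batch, items_q, 8, 16)    # (…, items_q, mv_channels, 16)
mv_kv = torch.randn(batch, items_kv, 8, 16)  # (…, items_kv, mv_channels, 16)
s_q = torch.randn(batch, items_q, 4)         # (…, items_q, s_channels)
s_kv = torch.randn(batch, items_kv, 4)       # (…, items_kv, s_channels)

# attn is the CrossAttention module built in the sketch above.
outputs_mv, output_scalars = attn(mv_kv, mv_q, scalars_kv=s_kv, scalars_q=s_q)
# outputs_mv: (2, 10, 8, 16), output_scalars: (2, 10, 4)

The key/value tensors determine which items are attended over, while the query tensors fix the number of output items, which is why both outputs carry the items_q dimension.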