lgatr.layers.attention.self_attention.SelfAttention

class lgatr.layers.attention.self_attention.SelfAttention(config)[source]

Bases: Module

L-GATr self-attention.

Constructs queries, keys, and values, computes attention, and projects linearly to outputs.

Parameters:

config (SelfAttentionConfig) – Attention configuration.
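
A minimal construction sketch is shown below; the SelfAttentionConfig import path and its field names (in_mv_channels, out_mv_channels, in_s_channels, out_s_channels, num_heads) are assumptions for illustration and should be checked against the actual config dataclass.

from lgatr.layers.attention.self_attention import SelfAttention
# Assumed import path and field names for the config; verify against SelfAttentionConfig.
from lgatr.layers.attention.config import SelfAttentionConfig

config = SelfAttentionConfig(
    in_mv_channels=16,   # multivector channels per item
    out_mv_channels=16,
    in_s_channels=32,    # auxiliary scalar channels per item
    out_s_channels=32,
    num_heads=8,
)
attention = SelfAttention(config)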

forward(multivectors, additional_qk_features_mv=None, scalars=None, additional_qk_features_s=None, **attn_kwargs)[source]

Computes self-attention.

Schematically, the computation is the following:

# For each head
queries = linear_channels(inputs)
keys = linear_channels(inputs)
values = linear_channels(inputs)
hidden = attention_items(queries, keys, values, biases=biases)
head_output = linear_channels(hidden)

# Combine results
output = concatenate_heads(head_output)
Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, items, mv_channels, 16).

  • additional_qk_features_mv (None or torch.Tensor) – Additional multivector Q/K features with shape (…, items, add_qk_mv_channels, 16).

  • scalars (None or torch.Tensor) – Optional input scalars with shape (…, items, s_channels).

  • additional_qk_features_s (None or torch.Tensor) – Additional scalar Q/K features with shape (…, items, add_qk_s_channels).

  • **attn_kwargs – Optional keyword arguments passed to attention.

Return type:

Tuple[Tensor, Tensor]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, items, mv_channels, 16).

  • output_scalars (torch.Tensor) – Output scalars with shape (…, items, s_channels).
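
A hedged usage sketch of forward, assuming a layer configured with 16 multivector channels and 32 scalar channels on both input and output (the sizes and the attention variable are illustrative assumptions, reusing the construction sketch above):

import torch

batch, items = 4, 128

multivectors = torch.randn(batch, items, 16, 16)  # (…, items, mv_channels, 16)
scalars = torch.randn(batch, items, 32)           # (…, items, s_channels)

outputs_mv, output_scalars = attention(multivectors, scalars=scalars)

print(outputs_mv.shape)       # torch.Size([4, 128, 16, 16])
print(output_scalars.shape)   # torch.Size([4, 128, 32])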