lgatr.primitives.attention

Equivariant attention.

Functions

scaled_dot_product_attention(query, key, ...)

Execute scaled dot-product attention.

sdp_attention(q_mv, k_mv, v_mv, q_s, k_s, ...)

Equivariant geometric attention based on scaled dot products.

lgatr.primitives.attention.scaled_dot_product_attention(query, key, value, **attn_kwargs)

Execute scaled dot-product attention. The attention backend is determined dynamically based on the attn_kwargs provided.

Parameters:
  • query (torch.Tensor) – Tensor of shape (…, items_out, channels)

  • key (torch.Tensor) – Tensor of shape (…, items_in, channels)

  • value (torch.Tensor) – Tensor of shape (…, items_in, channels)

  • **attn_kwargs – Optional keyword arguments passed to the attention backend; they also determine which backend is used.

Returns:

Tensor of shape (…, items_out, channels)

Return type:

torch.Tensor
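The backend is chosen from attn_kwargs, so the exact kernel is not fixed by this signature. Assuming the default path computes standard scaled dot-product attention, softmax(Q K^T / sqrt(channels)) V over the last two axes, the sketch below writes that computation out by hand and compares it with PyTorch's torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0 or later). The helper reference_sdpa is hypothetical and is not part of lgatr.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(query, key, value):
    """Hypothetical helper: plain scaled dot-product attention.

    query: (..., items_out, channels); key, value: (..., items_in, channels).
    """
    scale = 1.0 / math.sqrt(query.shape[-1])
    # Logits between every output item and every input item: (..., items_out, items_in)
    logits = query @ key.transpose(-2, -1) * scale
    # Normalize over the input items
    weights = torch.softmax(logits, dim=-1)
    # Weighted sum of values: (..., items_out, channels)
    return weights @ value


# Leading axes (batch=2, heads=4) are covered by the "..." in the shapes.
query = torch.randn(2, 4, 5, 8)  # items_out = 5
key = torch.randn(2, 4, 7, 8)    # items_in = 7
value = torch.randn(2, 4, 7, 8)

out_ref = reference_sdpa(query, key, value)
out_torch = F.scaled_dot_product_attention(query, key, value)
print(out_ref.shape)  # torch.Size([2, 4, 5, 8])
```

All leading axes, here batch and heads, are treated as independent, so the output keeps the shape (…, items_out, channels).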

lgatr.primitives.attention.sdp_attention(q_mv, k_mv, v_mv, q_s, k_s, v_s, **attn_kwargs)

Equivariant geometric attention based on scaled dot products.

Expects multivector and scalar queries, keys, and values as inputs and computes multivector and scalar outputs as follows:

attn_weights[..., i, j] = softmax_j[
    ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :])
    + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :])
]
out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm
out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm
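The formula above can be transcribed term by term with einsum. This is a sketch under stated assumptions, not lgatr's implementation: the helper explicit_geometric_attention and the METRIC_SIGNS vector are hypothetical. It assumes the 16 components belong to the spacetime algebra Cl(1,3) with metric diag(+1, -1, -1, -1) in the component order listed in the code, that ga_inner_product(x, y) = <reverse(x) y>_0 (which reduces to one sign per component), and that `norm` corresponds to the usual 1/sqrt(d) scaling of the logits before the softmax, so the weights are already normalized.

```python
import math

import torch

# Assumption (not stated in the docs): components ordered as
# 1, e0, e1, e2, e3, e01, e02, e03, e12, e13, e23, e012, e013, e023, e123, e0123,
# metric diag(+1, -1, -1, -1), and ga_inner_product(x, y) = <reverse(x) y>_0.
METRIC_SIGNS = torch.tensor(
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0]
)


def explicit_geometric_attention(q_mv, k_mv, v_mv, q_s, k_s, v_s):
    """Hypothetical helper that transcribes the attention formula literally."""
    # GA inner product summed over multivector channels c and components x
    ga_logits = torch.einsum("...icx,...jcx->...ij", q_mv * METRIC_SIGNS.to(q_mv), k_mv)
    # Euclidean inner product over scalar channels
    s_logits = torch.einsum("...ic,...jc->...ij", q_s, k_s)
    # Assumption: standard 1/sqrt(d) scaling, d = total number of feature components
    d = q_mv.shape[-2] * q_mv.shape[-1] + q_s.shape[-1]
    attn_weights = torch.softmax((ga_logits + s_logits) / math.sqrt(d), dim=-1)
    # Weighted sums over the key/value items j
    out_mv = torch.einsum("...ij,...jcx->...icx", attn_weights, v_mv)
    out_s = torch.einsum("...ij,...jc->...ic", attn_weights, v_s)
    return out_mv, out_s


items_out, items_in, mv_channels, s_channels = 3, 5, 2, 4
q_mv = torch.randn(items_out, mv_channels, 16)
k_mv = torch.randn(items_in, mv_channels, 16)
v_mv = torch.randn(items_in, mv_channels, 16)
q_s = torch.randn(items_out, s_channels)
k_s = torch.randn(items_in, s_channels)
v_s = torch.randn(items_in, s_channels)

out_mv, out_s = explicit_geometric_attention(q_mv, k_mv, v_mv, q_s, k_s, v_s)
print(out_mv.shape, out_s.shape)  # torch.Size([3, 2, 16]) torch.Size([3, 4])
```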

Parameters:
  • q_mv (torch.Tensor) – Multivector queries with shape (…, items_out, mv_channels, 16)

  • k_mv (torch.Tensor) – Multivector keys with shape (…, items_in, mv_channels, 16)

  • v_mv (torch.Tensor) – Multivector values with shape (…, items_in, mv_channels, 16)

  • q_s (torch.Tensor) – Scalar queries with shape (…, items_out, s_channels)

  • k_s (torch.Tensor) – Scalar keys with shape (…, items_in, s_channels)

  • v_s (torch.Tensor) – Scalar values with shape (…, items_in, s_channels)

  • **attn_kwargs – Optional keyword arguments passed to the attention backend.

Returns:

  • outputs_mv (torch.Tensor) – Multivector result with shape (…, items_out, mv_channels, 16)

  • outputs_s (torch.Tensor) – Scalar result with shape (…, items_out, s_channels)

Return type:

Tuple[Tensor, Tensor]
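Because the GA inner product is a signed Euclidean dot product over the 16 components, the geometric attention can be reduced to a single call of ordinary scaled dot-product attention: fold the signs into the queries, flatten the multivector channels, and concatenate them with the scalar channels. The sketch below does this with torch.nn.functional.scaled_dot_product_attention as a stand-in backend. The helper geometric_attention_via_sdpa and the METRIC_SIGNS vector are hypothetical; the sign convention assumes Cl(1,3) with metric diag(+1, -1, -1, -1), the component order in the code, and the inner product <reverse(x) y>_0. Whether lgatr implements sdp_attention this way is not stated here.

```python
import torch
import torch.nn.functional as F

# Assumption: components ordered as 1, e0, e1, e2, e3, e01, e02, e03, e12, e13, e23,
# e012, e013, e023, e123, e0123 with metric diag(+1, -1, -1, -1).
METRIC_SIGNS = torch.tensor(
    [1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0]
)


def geometric_attention_via_sdpa(q_mv, k_mv, v_mv, q_s, k_s, v_s):
    """Hypothetical helper: geometric attention as one standard attention call."""
    mv_channels = v_mv.shape[-2]
    # Fold the metric signs into the queries so a plain dot product gives the GA inner product
    q_mv_signed = q_mv * METRIC_SIGNS.to(q_mv)
    # (..., items, mv_channels, 16) -> (..., items, mv_channels * 16), then append scalars
    q = torch.cat([q_mv_signed.flatten(-2), q_s], dim=-1)
    k = torch.cat([k_mv.flatten(-2), k_s], dim=-1)
    v = torch.cat([v_mv.flatten(-2), v_s], dim=-1)
    # Standard attention; PyTorch scales the logits by 1/sqrt(q.shape[-1]) by default
    out = F.scaled_dot_product_attention(q, k, v)
    # Split the result back into multivector and scalar parts
    out_mv = out[..., : mv_channels * 16].unflatten(-1, (mv_channels, 16))
    out_s = out[..., mv_channels * 16 :]
    return out_mv, out_s


# Shapes: (batch, heads, items, ...) with items_out = 5 and items_in = 7
q_mv = torch.randn(2, 3, 5, 2, 16)
k_mv = torch.randn(2, 3, 7, 2, 16)
v_mv = torch.randn(2, 3, 7, 2, 16)
q_s = torch.randn(2, 3, 5, 4)
k_s = torch.randn(2, 3, 7, 4)
v_s = torch.randn(2, 3, 7, 4)

out_mv, out_s = geometric_attention_via_sdpa(q_mv, k_mv, v_mv, q_s, k_s, v_s)
print(out_mv.shape, out_s.shape)  # torch.Size([2, 3, 5, 2, 16]) torch.Size([2, 3, 5, 4])
```

Under the assumed convention, the attention weights depend on the multivectors only through Lorentz-invariant inner products, and the outputs are weighted sums of the values, so transforming all multivector inputs by the same Lorentz transformation transforms out_mv accordingly while out_s is unchanged. Folding the signs into one operand keeps the metric out of the attention kernel itself, so any Euclidean attention backend can be reused.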