lgatr.primitives.attention
Equivariant attention.
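Functions
scaled_dot_product_attention(query, key, value, **attn_kwargs) | Execute scaled dot-product attention.
sdp_attention(q_mv, k_mv, v_mv, q_s, k_s, v_s, **attn_kwargs) | Equivariant geometric attention based on scaled dot products.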
- lgatr.primitives.attention.scaled_dot_product_attention(query, key, value, **attn_kwargs)
Execute scaled dot-product attention. The attention backend is determined dynamically based on the attn_kwargs provided.
- Parameters:
query (torch.Tensor) – Tensor of shape (…, items_out, channels)
key (torch.Tensor) – Tensor of shape (…, items_in, channels)
value (torch.Tensor) – Tensor of shape (…, items_in, channels)
**attn_kwargs – Optional keyword arguments forwarded to the attention backend.
- Returns:
Tensor of shape (…, items_out, channels)
- Return type:
torch.Tensor
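The shape contract above can be illustrated with a small sketch. It does not call lgatr; it uses PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` as a stand-in, on the assumption that the default backend behaves like standard scaled dot-product attention. The batch and head sizes are arbitrary example values.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Leading dimensions (batch=2, heads=4) play the role of "…" in the docs.
query = torch.randn(2, 4, 5, 8)  # (…, items_out=5, channels=8)
key = torch.randn(2, 4, 7, 8)    # (…, items_in=7, channels=8)
value = torch.randn(2, 4, 7, 8)  # (…, items_in=7, channels=8)

# Stand-in for the library call (assumption: same semantics as PyTorch's SDPA).
out = F.scaled_dot_product_attention(query, key, value)

# Manual reference: softmax over the items_in axis of the scaled logits.
logits = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
manual = torch.softmax(logits, dim=-1) @ value

print(out.shape)
print(torch.allclose(out, manual, atol=1e-5, rtol=1e-4))
```

The output keeps the query's item axis (items_out) and the value's channel axis, while the items_in axis is summed out by the attention weights.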
- lgatr.primitives.attention.sdp_attention(q_mv, k_mv, v_mv, q_s, k_s, v_s, **attn_kwargs)
Equivariant geometric attention based on scaled dot products.
Expects both multivector and scalar queries, keys, and values as inputs. Then this function computes multivector and scalar outputs in the following way:
attn_weights[..., i, j] = softmax_j[ ga_inner_product(q_mv[..., i, :, :], k_mv[..., j, :, :]) + euclidean_inner_product(q_s[..., i, :], k_s[..., j, :]) ]
out_mv[..., i, c, :] = sum_j attn_weights[..., i, j] v_mv[..., j, c, :] / norm
out_s[..., i, c] = sum_j attn_weights[..., i, j] v_s[..., j, c] / norm
- Parameters:
q_mv (torch.Tensor) – Multivector queries with shape (…, items_out, mv_channels, 16)
k_mv (torch.Tensor) – Multivector keys with shape (…, items_in, mv_channels, 16)
v_mv (torch.Tensor) – Multivector values with shape (…, items_in, mv_channels, 16)
q_s (torch.Tensor) – Scalar queries with shape (…, items_out, s_channels)
k_s (torch.Tensor) – Scalar keys with shape (…, items_in, s_channels)
v_s (torch.Tensor) – Scalar values with shape (…, items_in, s_channels)
**attn_kwargs – Optional keyword arguments forwarded to the attention backend.
- Return type:
Tuple[Tensor, Tensor]
- Returns:
outputs_mv (torch.Tensor) – Multivector result with shape (…, items_out, mv_channels, 16)
outputs_s (torch.Tensor) – Scalar result with shape (…, items_out, s_channels)
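One way to read the formula for sdp_attention is that the geometric and Euclidean inner products can be folded into a single dot product, so an ordinary attention kernel can do the work. The sketch below illustrates that idea and is not the library's implementation. It makes several assumptions: the function name `sdp_attention_sketch` is hypothetical, the geometric inner product is taken to be diagonal in the 16 blade components with signs derived from a (+, -, -, -) signature, the blades are assumed to be ordered by grade and then lexicographically, and the normalization is left to the default scaling of PyTorch's `scaled_dot_product_attention`.

```python
import itertools
import math

import torch
import torch.nn.functional as F

# Assumed signature (+, -, -, -) and grade-then-lexicographic blade ordering.
SIGNATURE = (1.0, -1.0, -1.0, -1.0)
metric = torch.tensor(
    [
        math.prod(SIGNATURE[i] for i in blade)  # product of squared basis vectors
        for grade in range(5)
        for blade in itertools.combinations(range(4), grade)
    ]
)  # shape (16,), one sign per multivector component


def sdp_attention_sketch(q_mv, k_mv, v_mv, q_s, k_s, v_s):
    """Hypothetical reference for the documented formula, not lgatr's code."""
    mv_channels = v_mv.shape[-2]
    # Applying the metric signs to the queries turns the assumed geometric
    # inner product into a plain dot product over flattened components.
    q = torch.cat([(q_mv * metric).flatten(-2), q_s], dim=-1)
    k = torch.cat([k_mv.flatten(-2), k_s], dim=-1)
    v = torch.cat([v_mv.flatten(-2), v_s], dim=-1)
    out = F.scaled_dot_product_attention(q, k, v)
    # Split the result back into multivector and scalar parts.
    out_mv = out[..., : mv_channels * 16].unflatten(-1, (mv_channels, 16))
    out_s = out[..., mv_channels * 16 :]
    return out_mv, out_s


torch.manual_seed(0)
q_mv, q_s = torch.randn(2, 5, 3, 16), torch.randn(2, 5, 4)  # items_out = 5
k_mv, k_s = torch.randn(2, 7, 3, 16), torch.randn(2, 7, 4)  # items_in = 7
v_mv, v_s = torch.randn(2, 7, 3, 16), torch.randn(2, 7, 4)

out_mv, out_s = sdp_attention_sketch(q_mv, k_mv, v_mv, q_s, k_s, v_s)
print(out_mv.shape, out_s.shape)
```

Because the multivector and scalar values are concatenated before the call, both outputs share the same attention weights, which matches the single attn_weights tensor in the formula.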