lgatr.nets.conditional_lgatr_slim.CrossAttention

class lgatr.nets.conditional_lgatr_slim.CrossAttention(q_v_channels, kv_v_channels, q_s_channels, kv_s_channels, num_heads, attn_ratio=1, dropout_prob=None)

Bases: Module

Cross-attention module for Lorentz vectors and scalar features.
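To illustrate what makes attention over Lorentz vectors possible, the sketch below assumes, following the general L-GATr construction rather than this module's verified source, that attention scores combine Minkowski inner products of query and key vectors (summed over channels) with ordinary dot products of the scalar channels. Both terms are Lorentz invariants, so the scores stay unchanged when every vector is boosted. The metric signature (+, -, -, -) and the helper names `minkowski_inner`, `attention_logits`, and `boost_x` are assumptions made for this example only.

```python
import math
import torch

# Assumption: Minkowski metric with signature (+, -, -, -), time component first.
METRIC = torch.tensor([1.0, -1.0, -1.0, -1.0])


def minkowski_inner(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Minkowski inner product over the last axis (length 4)."""
    return (a * b * METRIC).sum(dim=-1)


def attention_logits(q_v, k_v, q_s, k_s):
    """Hypothetical single-head attention scores (not lgatr's code).

    q_v: (n_q, c_v, 4) query vectors    k_v: (n_kv, c_v, 4) key vectors
    q_s: (n_q, c_s)    query scalars    k_s: (n_kv, c_s)    key scalars
    Returns scores of shape (n_q, n_kv).
    """
    # Pair every query item with every key item and sum the Minkowski
    # products over the vector channels -> (n_q, n_kv).
    vec_term = minkowski_inner(q_v[:, None], k_v[None, :]).sum(dim=-1)
    # Scalars are already invariant, so a plain dot product suffices.
    scal_term = q_s @ k_s.T
    return (vec_term + scal_term) / math.sqrt(4 * q_v.shape[-2] + q_s.shape[-1])


def boost_x(rapidity: float) -> torch.Tensor:
    """4x4 Lorentz boost along x acting on (t, x, y, z)."""
    ch, sh = math.cosh(rapidity), math.sinh(rapidity)
    return torch.tensor([[ch, sh, 0.0, 0.0],
                         [sh, ch, 0.0, 0.0],
                         [0.0, 0.0, 1.0, 0.0],
                         [0.0, 0.0, 0.0, 1.0]])


torch.manual_seed(0)
q_v, k_v = torch.randn(3, 2, 4), torch.randn(5, 2, 4)  # 3 query items, 5 condition items
q_s, k_s = torch.randn(3, 8), torch.randn(5, 8)
L = boost_x(0.7)

before = attention_logits(q_v, k_v, q_s, k_s)
after = attention_logits(q_v @ L.T, k_v @ L.T, q_s, k_s)  # boost every vector
print(before.shape)                               # torch.Size([3, 5])
print(torch.allclose(before, after, atol=1e-4))   # True
```

Scalars enter the scores untouched because they carry no Lorentz index; only the vector channels need the metric.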

forward(vectors, vectors_condition, scalars, scalars_condition, **attn_kwargs)
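Parameters:
  • vectors (torch.Tensor) – A tensor of shape (…, q_v_channels, 4) representing the Lorentz vectors from which queries are computed.

  • vectors_condition (torch.Tensor) – A tensor of shape (…, kv_v_channels, 4) representing the conditioning Lorentz vectors from which keys and values are computed.

  • scalars (torch.Tensor) – A tensor of shape (…, q_s_channels) representing the scalar features from which queries are computed.

  • scalars_condition (torch.Tensor) – A tensor of shape (…, kv_s_channels) representing the conditioning scalar features from which keys and values are computed.

  • **attn_kwargs (dict) – Additional keyword arguments for the attention function.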

Returns:

The updated vectors and scalars after cross-attention, with the same shapes as the input vectors and scalars, respectively.

Return type:

torch.Tensor, torch.Tensor
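To make the shape contract concrete, here is a hypothetical single-head stand-in written in plain PyTorch. It is not lgatr's implementation and omits num_heads, attn_ratio, dropout, and attn_kwargs. It assumes, as the q_*/kv_* constructor arguments suggest, that queries come from vectors/scalars and keys/values from the condition tensors. The vector projections only mix channels, so output vectors transform like the inputs under Lorentz transformations. `ToyCrossAttention` and `METRIC` are illustrative names, and the (+, -, -, -) metric is an assumption.

```python
import math
import torch
from torch import nn

# Assumption: Minkowski metric with signature (+, -, -, -), time component first.
METRIC = torch.tensor([1.0, -1.0, -1.0, -1.0])


class ToyCrossAttention(nn.Module):
    """Hypothetical single-head sketch, NOT lgatr's CrossAttention."""

    def __init__(self, q_v_channels, kv_v_channels, q_s_channels, kv_s_channels):
        super().__init__()
        # Vector maps mix channels only (no bias, no action on the 4-vector
        # axis), so they commute with Lorentz transformations.
        self.q_v = nn.Linear(q_v_channels, q_v_channels, bias=False)
        self.k_v = nn.Linear(kv_v_channels, q_v_channels, bias=False)
        self.v_v = nn.Linear(kv_v_channels, q_v_channels, bias=False)
        # Scalars are invariant, so ordinary affine layers are fine.
        self.q_s = nn.Linear(q_s_channels, q_s_channels)
        self.k_s = nn.Linear(kv_s_channels, q_s_channels)
        self.v_s = nn.Linear(kv_s_channels, q_s_channels)
        self.scale = 1.0 / math.sqrt(4 * q_v_channels + q_s_channels)

    @staticmethod
    def _mix(layer, v):
        # (..., c_in, 4) -> (..., c_out, 4): apply the layer along channels.
        return layer(v.transpose(-1, -2)).transpose(-1, -2)

    def forward(self, vectors, vectors_condition, scalars, scalars_condition):
        q_vec = self._mix(self.q_v, vectors)             # (..., n_q, c_q, 4)
        k_vec = self._mix(self.k_v, vectors_condition)   # (..., n_kv, c_q, 4)
        v_vec = self._mix(self.v_v, vectors_condition)   # (..., n_kv, c_q, 4)
        q_sca = self.q_s(scalars)                        # (..., n_q, s_q)
        k_sca = self.k_s(scalars_condition)              # (..., n_kv, s_q)
        v_sca = self.v_s(scalars_condition)              # (..., n_kv, s_q)

        # Invariant scores: Minkowski products summed over channels + scalar dots.
        vec = torch.einsum("...icd,...jcd->...ij", q_vec, k_vec * METRIC)
        sca = torch.einsum("...ic,...jc->...ij", q_sca, k_sca)
        attn = torch.softmax((vec + sca) * self.scale, dim=-1)  # (..., n_q, n_kv)

        # Weighted sums of the condition's values, one row per query item.
        out_v = torch.einsum("...ij,...jcd->...icd", attn, v_vec)
        out_s = attn @ v_sca
        return out_v, out_s


torch.manual_seed(0)
layer = ToyCrossAttention(q_v_channels=2, kv_v_channels=3, q_s_channels=8, kv_s_channels=5)
vectors, scalars = torch.randn(10, 2, 4), torch.randn(10, 8)        # 10 query items
vectors_c, scalars_c = torch.randn(7, 3, 4), torch.randn(7, 5)      # 7 condition items
out_v, out_s = layer(vectors, vectors_c, scalars, scalars_c)
print(out_v.shape, out_s.shape)  # torch.Size([10, 2, 4]) torch.Size([10, 8])
```

Only the query-side shapes survive: the condition's item count and channel counts are absorbed by the attention weights and the key/value projections.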