lgatr.nets.conditional_lgatr_slim.ConditionalLGATrSlim

class lgatr.nets.conditional_lgatr_slim.ConditionalLGATrSlim(in_v_channels, condition_v_channels, out_v_channels, hidden_v_channels, in_s_channels, condition_s_channels, out_s_channels, hidden_s_channels, num_blocks, num_heads, nonlinearity='gelu', mlp_ratio=2, attn_ratio=1, num_layers_mlp=2, dropout_prob=None, checkpoint_blocks=False, compile=False)

Bases: Module

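Conditional L-GATr-slim network. It processes Lorentz vectors and scalars with self-attention and attends to a separate vector/scalar condition through cross-attention.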

forward(vectors, vectors_condition, scalars, scalars_condition, attn_kwargs=None, crossattn_kwargs=None)
Parameters:
  • vectors (torch.Tensor) – A tensor of shape (…, in_v_channels, 4) representing Lorentz vectors.

  • vectors_condition (torch.Tensor) – A tensor of shape (…, condition_v_channels, 4) representing the Lorentz-vector condition, which enters the network through cross-attention.

  • scalars (torch.Tensor) – A tensor of shape (…, in_s_channels) representing scalar features.

  • scalars_condition (torch.Tensor) – A tensor of shape (…, condition_s_channels) representing the scalar condition, which enters the network through cross-attention.

  • attn_kwargs (dict, optional) – Additional keyword arguments for the self-attention function.

  • crossattn_kwargs (dict, optional) – Additional keyword arguments for the cross-attention function.

Returns:

Output Lorentz vectors of shape (…, out_v_channels, 4) and output scalars of shape (…, out_s_channels).

Return type:

torch.Tensor, torch.Tensor
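The shape contract of forward can be summarized as a small dependency-free check. This is a sketch, not part of lgatr: `ChannelConfig` and `output_shapes` are hypothetical names. It assumes that `vectors` and `scalars` share their leading dimensions (…) because they describe the same tokens, that the two condition tensors likewise share their own leading dimensions, and that the outputs keep the leading dimensions of the inputs.

```python
from typing import NamedTuple, Tuple

Shape = Tuple[int, ...]


class ChannelConfig(NamedTuple):
    """Hypothetical container mirroring the channel arguments of ConditionalLGATrSlim."""
    in_v_channels: int
    condition_v_channels: int
    out_v_channels: int
    in_s_channels: int
    condition_s_channels: int
    out_s_channels: int


def output_shapes(cfg: ChannelConfig, vectors: Shape, vectors_condition: Shape,
                  scalars: Shape, scalars_condition: Shape) -> Tuple[Shape, Shape]:
    """Validate input shapes against the documented contract and return the
    expected (vectors_out, scalars_out) shapes. Hypothetical helper."""
    # Lorentz vectors have 4 components in the last axis.
    if vectors[-2:] != (cfg.in_v_channels, 4):
        raise ValueError(f"vectors: expected (..., {cfg.in_v_channels}, 4), got {vectors}")
    if vectors_condition[-2:] != (cfg.condition_v_channels, 4):
        raise ValueError(f"vectors_condition: expected (..., {cfg.condition_v_channels}, 4), "
                         f"got {vectors_condition}")
    if scalars[-1:] != (cfg.in_s_channels,):
        raise ValueError(f"scalars: expected (..., {cfg.in_s_channels}), got {scalars}")
    if scalars_condition[-1:] != (cfg.condition_s_channels,):
        raise ValueError(f"scalars_condition: expected (..., {cfg.condition_s_channels}), "
                         f"got {scalars_condition}")

    # Assumption: vectors and scalars describe the same tokens, so their leading dims agree.
    lead = vectors[:-2]
    if scalars[:-1] != lead:
        raise ValueError(f"leading dims differ: vectors {lead}, scalars {scalars[:-1]}")
    # Assumption: the same holds for the two condition tensors.
    if scalars_condition[:-1] != vectors_condition[:-2]:
        raise ValueError("leading dims of vectors_condition and scalars_condition differ")

    # Outputs keep the input leading dims; only the channel counts change.
    return lead + (cfg.out_v_channels, 4), lead + (cfg.out_s_channels,)


cfg = ChannelConfig(in_v_channels=3, condition_v_channels=2, out_v_channels=1,
                    in_s_channels=4, condition_s_channels=6, out_s_channels=1)
v_out, s_out = output_shapes(cfg, (8, 20, 3, 4), (8, 6, 2, 4), (8, 20, 4), (8, 6, 6))
print(v_out, s_out)  # (8, 20, 1, 4) (8, 20, 1)
```

Here the condition has 6 tokens while the main input has 20; the check allows this on the assumption that cross-attention does not require matching token counts.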