lgatr.nets.lgatr_slim.Dropout

class lgatr.nets.lgatr_slim.Dropout(dropout_prob)

Bases: Module

Dropout module for scalar and vector features.

For vector features, the same dropout mask is applied to all four components of each vector.
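The shared mask can be pictured as a mask with shape (…, v_channels, 1) whose trailing axis broadcasts over the four components. A consequence is that each Lorentz vector is either zeroed entirely or kept whole. The following is a minimal sketch of that idea, not lgatr's implementation: the helper `vector_dropout` is hypothetical, and it assumes PyTorch's standard inverted-dropout scaling (survivors multiplied by 1 / (1 - p)), which the documentation above does not specify.

```python
import torch
import torch.nn.functional as F


def vector_dropout(vectors: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: drop whole Lorentz vectors, one mask entry per vector."""
    # Mask shape (..., v_channels, 1); the trailing 1 broadcasts over the 4 components
    ones = torch.ones_like(vectors[..., :1])
    # F.dropout zeroes entries with probability p and rescales survivors by 1 / (1 - p)
    mask = F.dropout(ones, p=p, training=training)
    return vectors * mask


torch.manual_seed(0)
v = torch.randn(8, 16, 4)  # (batch, v_channels, 4)
out = vector_dropout(v, p=0.5)

# Each vector is either zeroed in all 4 components or kept whole (scaled by 2)
zeroed = (out == 0).all(dim=-1)
kept = (out == 2 * v).all(dim=-1)
```

Drawing one mask entry per vector, rather than per component, avoids ever producing a vector with some components zeroed and others kept.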

forward(vectors, scalars)
Parameters:
  • vectors (torch.Tensor) – A tensor of shape (…, v_channels, 4) representing Lorentz vectors.

  • scalars (torch.Tensor) – A tensor of shape (…, s_channels) representing scalar features.

Returns:

Two tensors with the same shapes as the inputs, containing the vectors and scalars after dropout.

Return type:

torch.Tensor, torch.Tensor
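To illustrate the calling convention (construct with a dropout probability, call with vectors and scalars, receive a pair of tensors back), here is a self-contained stand-in. `SketchDropout` is hypothetical and is not lgatr's code; it assumes scalars receive ordinary element-wise dropout, that inverted-dropout scaling is used, and that dropout is disabled in eval mode, as with standard PyTorch modules. With lgatr installed, the documented class would presumably be called the same way, e.g. `Dropout(0.1)(vectors, scalars)`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SketchDropout(nn.Module):
    """Hypothetical stand-in mirroring the documented interface; not lgatr's code."""

    def __init__(self, dropout_prob: float):
        super().__init__()
        self.dropout_prob = dropout_prob

    def forward(self, vectors: torch.Tensor, scalars: torch.Tensor):
        # One mask entry per vector channel, broadcast over the 4 Lorentz components
        mask = F.dropout(
            torch.ones_like(vectors[..., :1]), p=self.dropout_prob, training=self.training
        )
        # Assumption: scalars use ordinary element-wise dropout
        scalars = F.dropout(scalars, p=self.dropout_prob, training=self.training)
        return vectors * mask, scalars


drop = SketchDropout(dropout_prob=0.1)
vectors = torch.randn(32, 10, 8, 4)  # (batch, items, v_channels, 4)
scalars = torch.randn(32, 10, 16)    # (batch, items, s_channels)
v_out, s_out = drop(vectors, scalars)

drop.eval()  # nn.Module.eval() sets self.training = False, disabling dropout
v_eval, s_eval = drop(vectors, scalars)
```

Output shapes match input shapes in both modes; only training mode changes the values.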