lgatr.primitives.dropout

Grade dropout.

Functions

grade_dropout(x, p[, training])

Multivector dropout, dropping out grades independently.

lgatr.primitives.dropout.grade_dropout(x, p, training=True)

Multivector dropout, dropping out grades independently.

Parameters:
  • x (torch.Tensor) – Input data with shape (…, 16).

  • p (float) – Dropout probability (assumed the same for each grade).

  • training (bool) – Switches between train-time behaviour (grades are randomly dropped) and test-time behaviour (the input is returned unchanged).

Returns:

outputs – Inputs with dropout applied, shape (…, 16).

Return type:

torch.Tensor
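To make "dropping out grades independently" concrete, here is a minimal PyTorch sketch. It is not the lgatr source. It assumes the 16 components are ordered by grade (1 scalar, 4 vector, 6 bivector, 4 trivector and 1 pseudoscalar component), and it uses inverted-dropout scaling by 1/(1 - p), the same convention as torch.nn.functional.dropout. The names grade_dropout_sketch and GRADE_OF_COMPONENT are hypothetical.

```python
import torch

# Assumption: the 16 components are ordered by grade, i.e.
# 1 scalar, 4 vector, 6 bivector, 4 trivector and 1 pseudoscalar component.
GRADE_OF_COMPONENT = torch.tensor([0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4])


def grade_dropout_sketch(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Hypothetical re-implementation: zero each grade of each multivector with probability p."""
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be in [0, 1], got {p}")
    if not training or p == 0.0:
        return x  # test-time behaviour: identity
    if p == 1.0:
        return torch.zeros_like(x)  # every grade is dropped
    # One Bernoulli keep/drop decision per multivector and per grade: shape (..., 5)
    keep_prob = torch.full((*x.shape[:-1], 5), 1.0 - p, device=x.device)
    keep = torch.bernoulli(keep_prob)
    # Copy each grade's decision to all of its components: shape (..., 16)
    mask = keep[..., GRADE_OF_COMPONENT.to(x.device)].to(x.dtype)
    # Inverted dropout: rescale survivors so the expectation equals x
    return x * mask / (1.0 - p)


# Usage: a batch of 4 multivectors
x = torch.randn(4, 16)
y_train = grade_dropout_sketch(x, p=0.5)                  # some grades zeroed, the rest doubled
y_eval = grade_dropout_sketch(x, p=0.5, training=False)  # identical to x
```

All components of a grade are kept or zeroed together, so each output is a rescaled sum of grade projections of the input. Zeroing individual components instead would mix up the grade structure, which is likely why dropout is applied per grade here.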