lgatr.layers.linear.EquiLinear

class lgatr.layers.linear.EquiLinear(in_mv_channels, out_mv_channels, in_s_channels=None, out_s_channels=None, bias=True, initialization='default')[source]

Bases: Module

Lorentz-equivariant linear layer.

The forward pass maps multivector inputs with shape (…, in_mv_channels, 16) to multivector outputs with shape (…, out_mv_channels, 16) as

outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] * basis_map[b, x, y] * inputs[..., i, x]
                   = sum_i linear(inputs[..., i, :], weights[j, i, :])

plus an optional bias term for outputs[..., :, 0] (biases in other multivector components would break equivariance). Here basis_map holds precomputed basis elements (see lgatr.primitives.linear) and weights are the learnable weights of this layer. The basis_map contains 5 elements if the full Lorentz group is considered, and 10 elements if only the fully-connected subgroup is considered. See the use_fully_connected_subgroup parameter in lgatr.primitives.config.LGATrConfig for details.
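As a minimal sketch of this contraction (the function name and tensor layouts are illustrative assumptions, not the library's internal API), the formula corresponds to a single einsum over the precomputed basis:

    import torch

    def equi_linear_sketch(inputs, weights, basis_map):
        # inputs:    (..., in_mv_channels, 16)                  multivector inputs
        # weights:   (out_mv_channels, in_mv_channels, n_basis) learnable coefficients
        # basis_map: (n_basis, 16, 16)                          precomputed equivariant basis maps
        # outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] * basis_map[b, x, y] * inputs[..., i, x]
        return torch.einsum("jib,bxy,...ix->...jy", weights, basis_map, inputs)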

If there are auxiliary input scalars, they transform under a linear layer and mix with the scalar components of the multivector data. Note that in this layer (and only here) the auxiliary scalars are optional.

This layer supports four initialization schemes:

  • “default”: preserves (or, in practice, slightly reduces) the variance of the data in the forward pass

  • “small”: variance of outputs is approximately one order of magnitude smaller than for “default”

  • “unit_scalar”: outputs will be close to (1, 0, 0, …, 0)

  • “almost_unit_scalar”: similar to “unit_scalar”, but with more stochasticity

We use the “almost_unit_scalar” initialization to preprocess the second argument in the GeometricBilinears layer, and the “small” initialization to combine the different attention heads. All other linear layers in L-GATr use the “default” initialization.
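As a hedged sketch (the channel counts and batch size are arbitrary example values), the effect of the initialization scheme on the output scale can be checked empirically:

    import torch
    from lgatr.layers.linear import EquiLinear

    mv = torch.randn(4096, 8, 16)  # (batch, in_mv_channels, 16)
    for init in ("default", "small"):
        layer = EquiLinear(in_mv_channels=8, out_mv_channels=8, initialization=init)
        out_mv, _ = layer(mv)
        print(init, out_mv.std().item())
    # Per the description above, the "small" standard deviation should come out
    # roughly an order of magnitude below "default".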

Parameters:
  • in_mv_channels (int) – Input multivector channels

  • out_mv_channels (int) – Output multivector channels

  • bias (bool) – Whether a bias term is added to the scalar component of the multivector outputs

  • in_s_channels (int or None) – Input scalar channels. If None, no auxiliary scalar inputs are expected.

  • out_s_channels (int or None) – Output scalar channels. If None, no auxiliary scalar outputs are returned.

  • initialization ({"default", "small", "unit_scalar", "almost_unit_scalar"}) – Initialization scheme, see EquiLinear description for more information.
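As a usage sketch (the channel counts are arbitrary example values), the layer can be constructed with or without auxiliary scalar channels:

    from lgatr.layers.linear import EquiLinear

    # Multivectors only: no auxiliary scalars are expected or returned.
    mv_only = EquiLinear(in_mv_channels=16, out_mv_channels=32)

    # Multivectors plus auxiliary scalars on both input and output.
    mv_and_s = EquiLinear(
        in_mv_channels=16,
        out_mv_channels=32,
        in_s_channels=64,
        out_s_channels=64,
        bias=True,
        initialization="default",
    )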

forward(multivectors, scalars=None)[source]

Maps input multivectors and scalars using the most general equivariant linear map.

Parameters:
  • multivectors (torch.Tensor) – Input multivectors with shape (…, in_mv_channels, 16)

  • scalars (None or torch.Tensor) – Optional input scalars with shape (…, in_s_channels)

Return type:

Tuple[Tensor, Optional[Tensor]]

Returns:

  • outputs_mv (torch.Tensor) – Output multivectors with shape (…, out_mv_channels, 16)

  • outputs_s (None or torch.Tensor) – Output scalars with shape (…, out_s_channels)
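A minimal forward-pass sketch with the documented shapes (batch size and channel counts are arbitrary example values):

    import torch
    from lgatr.layers.linear import EquiLinear

    layer = EquiLinear(in_mv_channels=16, out_mv_channels=32, in_s_channels=64, out_s_channels=8)
    multivectors = torch.randn(128, 16, 16)  # (..., in_mv_channels, 16)
    scalars = torch.randn(128, 64)           # (..., in_s_channels)

    outputs_mv, outputs_s = layer(multivectors, scalars=scalars)
    assert outputs_mv.shape == (128, 32, 16)  # (..., out_mv_channels, 16)
    assert outputs_s.shape == (128, 8)        # (..., out_s_channels)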

reset_parameters(initialization, gain=1.0, additional_factor=0.5773502691896258)[source]

Initializes the weights of the linear layer.

We follow the Kaiming initialization philosophy, ensuring that the variance of the activations is preserved during the forward pass. Our implementation deviates from the torch.nn.Linear default initialization to take the communication between scalar and multivector channels in our linear layer into account. For more information, see the inline comments in the code.

Parameters:
  • initialization ({"default", "small", "unit_scalar", "almost_unit_scalar"}) – Initialization scheme, see EquiLinear description for more information.

  • gain (float) – Gain factor for the activations. Should be 1.0 if the previous layer has no activation, sqrt(2) if it has a ReLU activation, and so on. Can be computed with torch.nn.init.calculate_gain().

  • additional_factor (float) – Empirically, it has been found that slightly decreasing the data variance at each layer gives better performance. In particular, the PyTorch default initialization uses an additional factor of 1/sqrt(3) (cancelling the factor of sqrt(3) that naturally arises when computing the bounds of a uniform initialization). A discussion of this was (to the best of our knowledge) never published, but see https://github.com/pytorch/pytorch/issues/57109 and https://soumith.ch/files/20141213_gplus_nninit_discussion.htm.

Return type:

None
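A minimal re-initialization sketch, assuming the layer's input comes from a ReLU activation (the channel counts are arbitrary example values):

    import torch
    from lgatr.layers.linear import EquiLinear

    layer = EquiLinear(in_mv_channels=8, out_mv_channels=8)
    gain = torch.nn.init.calculate_gain("relu")  # sqrt(2), as described above
    layer.reset_parameters("default", gain=gain)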