lgatr.layers.linear.EquiLinear
- class lgatr.layers.linear.EquiLinear(in_mv_channels, out_mv_channels, in_s_channels=None, out_s_channels=None, bias=True, initialization='default')[source]
Bases: Module
Linear layer.
The forward pass maps multivector inputs with shape (…, in_mv_channels, 16) to multivector outputs with shape (…, out_mv_channels, 16) as
outputs[..., j, y] = sum_{i, b, x} weights[j, i, b] basis_map[b, x, y] inputs[..., i, x] = sum_i linear(inputs[..., i, :], weights[j, i, :])
plus an optional bias term for outputs[..., :, 0] (biases in other multivector components would break equivariance). Here basis_map is precomputed (see gatr.primitives.linear) and weights are the learnable weights of this layer. The basis_map contains 5 elements if the full Lorentz group is considered, and 10 elements if only the fully-connected subgroup is considered; see the use_fully_connected_subgroup parameter in lgatr.primitives.config.LGATrConfig for details.
If there are auxiliary input scalars, they transform under a linear layer and mix with the scalar components of the multivector data. Note that in this layer (and only here) the auxiliary scalars are optional.
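For illustration, the contraction above can be written as a single einsum. This is a hedged sketch using the tensor names from the formula, not the actual implementation in the primitives module:

    import torch

    # Sketch of the equivariant linear map described above.
    # weights:   (out_mv_channels, in_mv_channels, num_basis_elements)
    # basis_map: (num_basis_elements, 16, 16)
    # inputs:    (..., in_mv_channels, 16)
    # returns:   (..., out_mv_channels, 16)
    def equi_linear_sketch(weights, basis_map, inputs):
        return torch.einsum("jib,bxy,...ix->...jy", weights, basis_map, inputs)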
This layer supports four initialization schemes:
“default”: preserves (or actually slightly reduces) the variance of the data in the forward pass
“small”: variance of outputs is approximately one order of magnitude smaller than for “default”
“unit_scalar”: outputs will be close to (1, 0, 0, …, 0)
“almost_unit_scalar”: similar to “unit_scalar”, but with more stochasticity
We use the “almost_unit_scalar” initialization to preprocess the second argument in the GeometricBilinears layer, and the “small” initialization to combine the different attention heads. All other linear layers in L-GATr use the “default” initialization.
- Parameters:
in_mv_channels (int) – Input multivector channels
out_mv_channels (int) – Output multivector channels
bias (bool) – Whether a bias term is added to the scalar component of the multivector outputs
in_s_channels (int or None) – Input scalar channels. If None, no scalars are expected nor returned.
out_s_channels (int or None) – Output scalar channels. If None, no scalars are expected nor returned.
initialization ({"default", "small", "unit_scalar", "almost_unit_scalar"}) – Initialization scheme, see
EquiLinear
description for more information.
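As a hedged usage sketch (the channel counts are arbitrary), the layer can be constructed from its documented import path:

    from lgatr.layers.linear import EquiLinear

    # Map 8 -> 16 multivector channels and 4 -> 8 auxiliary scalar channels.
    layer = EquiLinear(
        in_mv_channels=8,
        out_mv_channels=16,
        in_s_channels=4,
        out_s_channels=8,
        bias=True,
        initialization="default",
    )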
- forward(multivectors, scalars=None)[source]
Maps input multivectors and scalars using the most general equivariant linear map.
- Parameters:
multivectors (torch.Tensor) – Input multivectors with shape (…, in_mv_channels, 16)
scalars (None or torch.Tensor) – Optional input scalars with shape (…, in_s_channels)
- Return type:
Tuple[Tensor, Optional[Tensor]]
- Returns:
outputs_mv (torch.Tensor) – Output multivectors with shape (…, out_mv_channels, 16)
outputs_s (None or torch.Tensor) – Output scalars with shape (…, out_s_channels)
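A hedged sketch of a forward call, reusing the layer constructed above with a batch dimension of 32:

    import torch

    mv = torch.randn(32, 8, 16)  # (..., in_mv_channels, 16)
    s = torch.randn(32, 4)       # (..., in_s_channels)
    out_mv, out_s = layer(mv, s)
    # out_mv has shape (32, 16, 16), out_s has shape (32, 8)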
- reset_parameters(initialization, gain=1.0, additional_factor=np.float64(0.5773502691896258))[source]
Initializes the weights of the linear layer.
We follow the initialization philosophy of Kaiming, ensuring that the variance of the activations is preserved during the forward pass. Our implementation deviates from the torch.nn.Linear default initialization to take into account the communication between scalar and multivector channels in our linear layer. For more information, see the inline comments in the code.
- Parameters:
initialization ({"default", "small", "unit_scalar", "almost_unit_scalar"}) – Initialization scheme, see
EquiLinear
description for more information.gain (float) – Gain factor for the activations. Should be 1.0 if previous layer has no activation, sqrt(2) if it has a ReLU activation, and so on. Can be computed with torch.nn.init.calculate_gain().
additional_factor (float) – Empirically, it has been found that slightly decreasing the data variance at each layer gives better performance. In particular, the PyTorch default initialization uses an additional factor of 1/sqrt(3) (cancelling the factor of sqrt(3) that naturally arises when computing the bounds of a uniform initialization). A discussion of this was (to the best of our knowledge) never published, but see https://github.com/pytorch/pytorch/issues/57109 and https://soumith.ch/files/20141213_gplus_nninit_discussion.htm.
- Return type:
None
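A hedged sketch of re-initializing the layer above when the preceding layer ends in a ReLU activation; the gain is computed with torch.nn.init.calculate_gain() as suggested:

    import torch

    # calculate_gain("relu") returns sqrt(2).
    layer.reset_parameters(
        initialization="default",
        gain=torch.nn.init.calculate_gain("relu"),
    )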