DenseSymm(symmetries, point_group=None, mode='auto', shape=None, **kwargs)
Implements a projection onto a symmetry group. The output is equivariant with respect to the symmetry operations in the group and can be averaged over the symmetry dimension to produce an invariant model.
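To make the equivariance statement concrete, here is a minimal sketch in plain JAX of the matrix form of such a layer for the translation group of a ring of sites. It is not NetKet's implementation. The helper names `translation_perms` and `dense_symm_matrix`, the convention that row `g` of the permutation array gathers `x[perms[g, i]]`, and the absence of a bias are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def translation_perms(n_sites):
    # Hypothetical helper: perms[g, i] = (i + g) % n_sites, the translations of a ring.
    idx = jnp.arange(n_sites)
    return (idx[:, None] + idx[None, :]) % n_sites       # [n_symm, n_sites]

def dense_symm_matrix(x, kernel, perms):
    # Hypothetical sketch, no bias.
    # x: [n_samples, n_sites], kernel: [features, n_sites], perms: [n_symm, n_sites]
    x_perm = x[:, perms]                                  # [n_samples, n_symm, n_sites]
    return jnp.einsum("ngi,fi->ngf", x_perm, kernel)      # [n_samples, n_symm, features]

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
n_sites, features = 6, 3
x = jax.random.normal(key_x, (4, n_sites))
kernel = jax.random.normal(key_w, (features, n_sites))
perms = translation_perms(n_sites)

y = dense_symm_matrix(x, kernel, perms)
print(y.shape)  # (4, 6, 3)

# Translate the input by one site: x_shifted[:, i] = x[:, (i + 1) % n_sites]
x_shifted = jnp.roll(x, -1, axis=1)
y_shifted = dense_symm_matrix(x_shifted, kernel, perms)

# Equivariance: the symmetry axis of the output is permuted in the same way.
print(jnp.allclose(y_shifted, jnp.roll(y, -1, axis=1), atol=1e-5))  # True
# Invariance: averaging over the symmetry axis removes the dependence on the shift.
print(jnp.allclose(y_shifted.mean(axis=1), y.mean(axis=1), atol=1e-5))  # True
```

Shifting the input by one site shifts the symmetry axis of the output by one position, which is what equivariance means here; the mean over that axis is unchanged, which gives the invariant model.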
Note: The output shape has changed to separate the feature and symmetry dimensions. The previous shape was [num_samples, num_symm*features]; the new shape is [num_samples, num_symm, features].
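Code written against the flat layout can be adapted with a reshape. The sketch below uses a stand-in array rather than a real layer output, and it assumes the flat layout is symmetry-major (all features of the first symmetry, then the second, and so on). The note does not state the ordering of the old layout, so treat that ordering as an assumption to check against your own code.

```python
import jax.numpy as jnp

n_samples, n_symm, features = 2, 4, 3

# Stand-in for an output in the new layout: [num_samples, num_symm, features]
y_new = jnp.arange(n_samples * n_symm * features, dtype=jnp.float32).reshape(
    n_samples, n_symm, features
)

# Flatten to [num_samples, num_symm*features] (symmetry-major ordering assumed)
y_flat = y_new.reshape(n_samples, n_symm * features)

# Recover the separated layout from a flat array with the same ordering
y_back = y_flat.reshape(n_samples, n_symm, features)

# With symmetry-major ordering, the second block of `features` entries belongs to symmetry 1
print(jnp.array_equal(y_flat[0, features:2 * features], y_new[0, 1]))  # True
```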
symmetries – A specification of the symmetry group. Can be given as a nk.graph.Graph, a nk.utils.PermutationGroup, or an array of shape [n_symm, n_sites] whose rows are the permutations corresponding to the symmetry transformations of the lattice.
point_group – The point group from which the space group is built. If symmetries is a graph, its default point group is overwritten.
mode – A string, one of "fft", "matrix", or "auto", specifying whether to use a fast Fourier transform, matrix multiplication, or a sensible default chosen based on the symmetry group.
shape – A tuple specifying the dimensions of the translation group.
features – The number of symmetry-reduced features. The full output shape per sample is [n_symm, features].
use_bias – A bool specifying whether to add a bias to the output (default: True).
mask – An optional array of shape [n_sites] consisting of ones and zeros that can be used to give the kernel a particular shape.
dtype – The datatype of the weights. Defaults to a 64-bit float.
precision – Optional argument specifying the numerical precision of the computation. See `jax.lax.Precision` for details.
kernel_init – Optional kernel initialization function. Defaults to variance scaling.
bias_init – Optional bias initialization function. Defaults to zero initialization.
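The mode and mask parameters can be illustrated together. For a pure translation group on a ring, the matrix form computes the sums y[n, g, f] = Σ_i x[n, (i + g) mod L] W[f, i], which form a circular cross-correlation, so an FFT gives the same result. The sketch below checks this numerically in plain JAX. The functions `dense_symm_matrix` and `dense_symm_fft` are hypothetical and do not mirror NetKet's internals. The sketch also assumes that the mask multiplies the kernel elementwise over sites.

```python
import jax
import jax.numpy as jnp

def dense_symm_matrix(x, kernel, n_sites):
    # Hypothetical matrix mode: gather x with every translation, contract the site axis.
    idx = jnp.arange(n_sites)
    perms = (idx[:, None] + idx[None, :]) % n_sites        # perms[g, i] = (i + g) % n_sites
    return jnp.einsum("ngi,fi->ngf", x[:, perms], kernel)  # [n_samples, n_symm, features]

def dense_symm_fft(x, kernel):
    # Hypothetical FFT mode: circular cross-correlation computed in Fourier space.
    xk = jnp.fft.fft(x, axis=-1)                           # [n_samples, n_sites]
    kk = jnp.fft.fft(kernel, axis=-1)                      # [features, n_sites]
    y = jnp.fft.ifft(xk[:, None, :] * jnp.conj(kk)[None, :, :], axis=-1).real
    return jnp.transpose(y, (0, 2, 1))                     # [n_samples, n_symm, features]

n_sites, features = 8, 2
key_x, key_w = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(key_x, (3, n_sites))
kernel = jax.random.normal(key_w, (features, n_sites))

# Hypothetical mask of ones and zeros: keep a site and its two ring neighbours.
mask = jnp.zeros(n_sites).at[jnp.array([0, 1, n_sites - 1])].set(1.0)
masked_kernel = kernel * mask                              # broadcasts over features

y_matrix = dense_symm_matrix(x, masked_kernel, n_sites)
y_fft = dense_symm_fft(x, masked_kernel)
print(jnp.allclose(y_matrix, y_fft, atol=1e-4))            # True

# With the mask, output g = 0 ignores site 4, which lies outside the kernel window.
x_far = x.at[:, 4].add(10.0)
print(jnp.allclose(dense_symm_matrix(x_far, masked_kernel, n_sites)[:, 0], y_matrix[:, 0]))  # True
```

The matrix form works for any permutation array. The FFT form relies on the cyclic structure of translations, which is presumably why "auto" selects the mode based on the symmetry group.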