netket.nn.DenseSymm

netket.nn.DenseSymm(symmetries, point_group=None, mode='auto', shape=None, **kwargs)

Implements a projection onto a symmetry group. The output is equivariant with respect to the symmetry operations in the group, and averaging it over the symmetry dimension produces an invariant model.
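The equivariance property can be checked numerically with a small JAX sketch. This is not NetKet's implementation: it assumes the symmetry group is the cyclic translation group of a 1D ring, a kernel of shape [features, n_sites], and the hypothetical helpers `translation_perms` and `dense_symm_matrix`, whose indexing convention is chosen here for illustration. Translating the input permutes the output along the symmetry axis, and averaging over that axis removes the dependence on the translation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def translation_perms(n_sites):
    """Hypothetical helper: permutation table [n_symm, n_sites] for cyclic
    translations of a ring. Row g maps site i to site (i + g) % n_sites."""
    idx = np.arange(n_sites)
    return np.stack([(idx + g) % n_sites for g in range(n_sites)])


def dense_symm_matrix(x, kernel, perms):
    """Hypothetical matrix-mode projection (assumed convention).

    x: [num_samples, n_sites], kernel: [features, n_sites],
    perms: [n_symm, n_sites]. Returns [num_samples, n_symm, features] with
    y[b, g, f] = sum_i kernel[f, i] * x[b, perms[g, i]].
    """
    x_perm = x[:, perms]  # [num_samples, n_symm, n_sites]
    return jnp.einsum("bgi,fi->bgf", x_perm, kernel)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
n_sites, features = 6, 3
x = jax.random.normal(k1, (4, n_sites))
kernel = jax.random.normal(k2, (features, n_sites))
perms = translation_perms(n_sites)

y = dense_symm_matrix(x, kernel, perms)
print(y.shape)  # (4, 6, 3)

# Translate every input configuration by h sites: x'[i] = x[(i + h) % n_sites].
h = 2
x_shifted = x[:, perms[h]]
y_shifted = dense_symm_matrix(x_shifted, kernel, perms)

# Equivariance: the output is permuted along the symmetry axis.
assert jnp.allclose(y_shifted, jnp.roll(y, -h, axis=1), atol=1e-5)

# Invariance: averaging over the symmetry axis gives identical features.
assert jnp.allclose(y_shifted.mean(axis=1), y.mean(axis=1), atol=1e-5)
```

The permutation table has the same [n_symm, n_sites] layout as the array form of the `symmetries` argument described below.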

Note: The output shape has changed to separate the feature and symmetry dimensions. The previous shape was [num_samples, num_symm*features]; the new shape is [num_samples, num_symm, features].

Parameters
  • symmetries – A specification of the symmetry group. Can be given as a nk.graph.Graph, a nk.utils.PermutationGroup, or an array of shape [n_symm, n_sites] specifying the permutations corresponding to the symmetry transformations of the lattice.

  • point_group – The point group from which the space group is built. If symmetries is a graph, this argument overrides the graph's default point group.

  • mode – A string, one of "fft", "matrix" or "auto", specifying whether to use a fast Fourier transform, matrix multiplication, or a sensible default chosen based on the symmetry group.

  • shape – A tuple specifying the dimensions of the translation group.

  • features – The number of symmetry-reduced features. The full output size is [n_symm, features].

  • use_bias – A bool specifying whether to add a bias to the output (default: True).

  • mask – An optional array of shape [n_sites] consisting of ones and zeros that can be used to give the kernel a particular shape.

  • dtype – The datatype of the weights. Defaults to a 64-bit float.

  • precision – Optional argument specifying the numerical precision of the computation. See `jax.lax.Precision` for details.

  • kernel_init – Optional kernel initialization function. Defaults to variance scaling.

  • bias_init – Optional bias initialization function. Defaults to zero initialization.
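Why the "fft" and "matrix" modes can compute the same thing is easiest to see for translations: for a cyclic translation group, the permutation-and-contract sum is a circular cross-correlation, which FFTs evaluate in O(n log n). The JAX sketch below is illustrative only and makes several assumptions not stated above: a 1D ring, cyclic translations only, a kernel of shape [features, n_sites], and a `mask` that is applied by elementwise multiplication with the kernel. The functions `dense_symm_matrix` and `dense_symm_fft_1d` are hypothetical and are not part of NetKet.

```python
import jax
import jax.numpy as jnp
import numpy as np


def dense_symm_matrix(x, kernel, perms):
    # Gather x in every permuted site ordering, then contract with the kernel:
    # y[b, g, f] = sum_i kernel[f, i] * x[b, perms[g, i]]
    return jnp.einsum("bgi,fi->bgf", x[:, perms], kernel)


def dense_symm_fft_1d(x, kernel):
    # For cyclic translations the same sum is a circular cross-correlation:
    # y[b, g, f] = IFFT(FFT(x) * conj(FFT(kernel)))[g] for a real kernel.
    x_k = jnp.fft.fft(x)[:, None, :]                 # [b, 1, n]
    w_k = jnp.conj(jnp.fft.fft(kernel))[None, :, :]  # [1, f, n]
    corr = jnp.fft.ifft(x_k * w_k).real              # [b, f, n]; last axis is g
    return jnp.swapaxes(corr, 1, 2)                  # [b, n_symm, f]


n_sites, features = 8, 2
idx = np.arange(n_sites)
# Row g is the translation i -> (i + g) % n_sites.
perms = np.stack([(idx + g) % n_sites for g in range(n_sites)])

k1, k2 = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(k1, (5, n_sites))
kernel = jax.random.normal(k2, (features, n_sites))

# Hypothetical mask of ones and zeros: keep kernel entries for site offsets
# 0, +1 and -1 only (index 7 is offset -1 on a ring of 8 sites).
mask = jnp.array([1, 1, 0, 0, 0, 0, 0, 1], dtype=x.dtype)
masked_kernel = kernel * mask  # broadcasts over the feature axis

y_matrix = dense_symm_matrix(x, masked_kernel, perms)
y_fft = dense_symm_fft_1d(x, masked_kernel)

print(y_matrix.shape)  # (5, 8, 2)
assert jnp.allclose(y_matrix, y_fft, atol=1e-4)
```

For a lattice whose translation group has more than one dimension (as described by the `shape` argument), the same idea would apply with `jnp.fft.fftn` over those axes. For groups that also contain point-group operations, a plain FFT over translations no longer covers every element, which is one reason a "matrix" fallback is useful.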