netket.nn.DenseEquivariant

netket.nn.DenseEquivariant(symmetries, mode='auto', shape=None, point_group=None, **kwargs)

A group convolution operation that is equivariant with respect to a symmetry group.

Acts on a feature map of symmetry poses of shape [num_samples, in_features, num_symm] and returns a feature map of poses of shape [num_samples, out_features, num_symm].

G-convolutions are described in Cohen et al. (http://proceedings.mlr.press/v48/cohenc16.pdf) and applied to quantum many-body problems in Roth et al. (https://arxiv.org/pdf/2104.05085.pdf).

The G-convolution generalizes the convolution to non-commutative (non-abelian) groups:

\[C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h\]

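The filter connecting output element \(g\) to input element \(h\) depends only on \(g^{-1}h\). Pairs \((g, h)\) and \((g', h')\) related by the same symmetry operation \(x\) (i.e. \(g = hx\) and \(g' = h'x\), so that \(g^{-1}h = g'^{-1}h' = x^{-1}\)) are therefore connected by the same filter \({\bf W}_{x^{-1}}\).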
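The formula and the filter-sharing rule can be checked numerically. The sketch below is plain NumPy, not NetKet's implementation: it builds the product table of a hypothetical toy group, S3 (the six permutations of three objects, which do not commute), forms the full kernel \({\bf W}_{g^{-1}h}\) by indexing one filter per group element with the product table, and verifies equivariance and filter sharing. It assumes that \(i\) in \(C^i_g\) labels output features and that the \(\cdot\) contracts over input features \(j\); the helper names (compose, group_conv, product_table, inverse, kernel_index) are illustrative only.

```python
import itertools

import numpy as np

# Hypothetical toy group: S3, the 6 permutations of three objects.
# S3 is non-abelian, so the order of products matters.
perms = list(itertools.permutations(range(3)))
index = {p: i for i, p in enumerate(perms)}
n_symm = len(perms)


def compose(a, b):
    # (a * b)(x) = a(b(x)): apply b first, then a
    return tuple(a[b[x]] for x in range(3))


# product_table[i, j] = index of perms[i] * perms[j]
product_table = np.array([[index[compose(a, b)] for b in perms] for a in perms])
identity = index[(0, 1, 2)]
# inverse[i] = index of perms[i]^{-1}, i.e. the column where row i hits the identity
inverse = np.argmax(product_table == identity, axis=1)
# kernel_index[g, h] = index of g^{-1} h
kernel_index = product_table[inverse]


def group_conv(f, W):
    """C^i_g = sum_h sum_j W^{ij}_{g^{-1} h} f^j_h.

    f: [num_samples, in_features, num_symm]
    W: [out_features, in_features, num_symm], one filter per group element
    returns: [num_samples, out_features, num_symm]
    """
    K = W[:, :, kernel_index]  # K[i, j, g, h] = W^{ij}_{g^{-1} h}
    return np.einsum("ijgh,njh->nig", K, f)


rng = np.random.default_rng(0)
f = rng.normal(size=(4, 2, n_symm))  # 4 samples, 2 input features
W = rng.normal(size=(3, 2, n_symm))  # 3 output features
C = group_conv(f, W)
print(C.shape)  # (4, 3, 6)

# Equivariance: transforming the input by u (f_h -> f_{u^{-1} h})
# transforms the output the same way (C_g -> C_{u^{-1} g}).
for u in range(n_symm):
    perm = product_table[inverse[u]]  # perm[h] = index of u^{-1} h
    assert np.allclose(group_conv(f[:, :, perm], W), C[:, :, perm])

# Filter sharing: every pair (g, h) with g = h x uses the filter W_{x^{-1}}.
for h in range(n_symm):
    for x in range(n_symm):
        assert kernel_index[product_table[h, x], h] == inverse[x]
```

Building the kernel by indexing with the product table is what makes a product table a sufficient specification of the group here: no other group structure is needed to evaluate the sum.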

Parameters
  • symmetries – A specification of the symmetry group. Can be given as an nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible representations, or a product table.

  • point_group – The point group from which the space group is built. If symmetries is a graph, this argument overrides the graph's default point group.

  • mode – A string, one of "fft", "irreps", "matrix" or "auto", specifying whether to use a fast Fourier transform over the translation group, a Fourier transform using the irreducible representations, or the full kernel matrix.

  • shape – A tuple specifying the dimensions of the translation group.

  • in_features – The number of symmetry-reduced input features. The full input size is n_symm*in_features.

  • out_features – The number of symmetry-reduced output features. The full output size is n_symm*out_features.

  • use_bias – A bool specifying whether to add a bias to the output (default: True).

  • mask – An optional array of shape [n_sites] consisting of ones and zeros that can be used to give the kernel a particular shape.

  • dtype – The datatype of the weights. Defaults to a 64-bit float.

  • precision – Optional argument specifying the numerical precision of the computation. See `jax.lax.Precision` for details.

  • kernel_init – Optional kernel initialization function. Defaults to variance scaling.

  • bias_init – Optional bias initialization function. Defaults to zero initialization.
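The trade-off behind the mode parameter can be illustrated on the simplest case, a cyclic translation group Z_L, where \(g^{-1}h = h - g \pmod L\) and the G-convolution reduces to a circular cross-correlation. The NumPy sketch below is not NetKet's code; the ring size, variable names and scalar bias are hypothetical. It computes the same output once by materialising the full L×L kernel matrix and once with FFTs, and checks that the two agree.

```python
import numpy as np

# Hypothetical toy case: translation group Z_L of a ring with L sites.
# Group elements are shifts 0..L-1, with g^{-1} h = (h - g) mod L.
L = 8
rng = np.random.default_rng(1)
f = rng.normal(size=(5, L))  # [num_samples, num_symm], one input feature
w = rng.normal(size=L)       # one real filter value per group element
b = 0.1                      # scalar bias, as when use_bias=True

# Matrix route: build the full kernel K[g, h] = w[(h - g) mod L].
g = np.arange(L)[:, None]
h = np.arange(L)[None, :]
K = w[(h - g) % L]
C_matrix = f @ K.T + b  # C[n, g] = sum_h K[g, h] f[n, h] + b

# FFT route: C_g = sum_h w[h - g] f[h] is a circular cross-correlation,
# so in Fourier space it is a pointwise product with conj(fft(w)) (w real).
C_fft = np.fft.ifft(np.fft.fft(f, axis=-1) * np.conj(np.fft.fft(w)), axis=-1).real + b

assert np.allclose(C_matrix, C_fft)
```

Materialising K costs O(L²) per sample, while the FFT route costs O(L log L); this is presumably why the layer offers Fourier-based modes as alternatives to constructing the full kernel matrix.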