netket.models.FastARNNConv1D

class netket.models.FastARNNConv1D(hilbert, layers, features, kernel_size, kernel_dilation=1, activation=<function selu>, use_bias=True, dtype=<class 'jax._src.numpy.lax_numpy.float64'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: netket.models.autoreg.AbstractARNN

Fast autoregressive neural network with 1D convolution layers.

See netket.nn.FastMaskedConv1D for a brief explanation of fast autoregressive sampling.
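Autoregressive sampling works because the network output for site i depends only on the sites before i. The sketch below illustrates that property with a hypothetical helper, causal_conv1d, written in plain jax.numpy for a single channel. It is not NetKet's FastMaskedConv1D and does not attempt the fast-sampling optimisation described there. It only shows the masking idea: shift the input right by one site (exclusive mask) and pad on the left so that every output looks backwards.

```python
import jax
import jax.numpy as jnp


def causal_conv1d(x, kernel, dilation=1, exclusive=True):
    """Hypothetical helper: single-channel 1D causal convolution over sites.

    x: (N,) site values; kernel: (k,) weights.
    With exclusive=True, output i depends only on inputs at sites < i;
    with exclusive=False, on inputs at sites <= i.
    """
    k = kernel.shape[0]
    n = x.shape[0]
    if exclusive:
        # Shift right by one site so that site i cannot see itself.
        x = jnp.concatenate([jnp.zeros((1,), x.dtype), x[:-1]])
    # Left-only padding: each output window reaches backwards only.
    pad = (k - 1) * dilation
    xp = jnp.concatenate([jnp.zeros((pad,), x.dtype), x])
    out = jnp.zeros((n,), x.dtype)
    for j in range(k):
        # Tap j reads the (shifted) input j * dilation sites behind i.
        start = pad - j * dilation
        out = out + kernel[j] * xp[start:start + n]
    return out


x = jnp.arange(1.0, 9.0)
y = causal_conv1d(x, jnp.array([0.5, -1.0, 2.0]), dilation=2)
```

Taking the Jacobian of the output with respect to the input confirms the property. With the exclusive mask it is strictly lower triangular: changing site j never affects outputs at sites up to and including j.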

Attributes
kernel_dilation: int = 1

dilation factor of the convolution kernel (default 1).
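How far back a stack of dilated convolutions can see follows from simple counting. The sketch below is a generic calculation, not taken from NetKet's source. It assumes every layer is an inclusive causal convolution with the same kernel_size and kernel_dilation. An exclusive first layer would additionally shift the whole window back by one site. The brute-force function tracks which input offsets reach a single output and checks two closed forms: the span of the window, and the number of distinct sites inside it.

```python
def span_closed(layers: int, kernel_size: int, dilation: int) -> int:
    # Distance from the earliest reachable site to the current one, plus one.
    return (kernel_size - 1) * dilation * layers + 1


def count_closed(layers: int, kernel_size: int, dilation: int) -> int:
    # Number of distinct sites actually reached inside that span.
    return (kernel_size - 1) * layers + 1


def receptive_field_bruteforce(layers: int, kernel_size: int, dilation: int):
    # Each causal layer maps position p to p - j * dilation, j = 0..kernel_size-1.
    positions = {0}
    for _ in range(layers):
        positions = {p - j * dilation for p in positions for j in range(kernel_size)}
    span = max(positions) - min(positions) + 1
    return span, len(positions)


print(receptive_field_bruteforce(layers=2, kernel_size=3, dilation=2))
```

When the same dilation d > 1 is used in every layer, the two numbers differ. Only every d-th site inside the span is reached, so the window has gaps.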

precision: Any = None

numerical precision of the computation, see jax.lax.Precision for details.
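For reference, the sketch below shows how a jax.lax.Precision value is passed to a JAX convolution primitive. It assumes, as the attribute description suggests, that the model forwards precision to its convolutions; how NetKet does that internally is not shown on this page. The example also uses rhs_dilation and left-only padding to build a causal, dilated 1D convolution. The shapes and values are illustrative only.

```python
import jax
import jax.numpy as jnp

# One batch element, one input channel, 8 sites: layout (N, C, W).
x = jnp.arange(1.0, 9.0, dtype=jnp.float32).reshape(1, 1, 8)
# One output channel, one input channel, kernel size 3: layout (O, I, W).
w = jnp.ones((1, 1, 3), dtype=jnp.float32)

kernel_size, dilation = 3, 2
pad = (kernel_size - 1) * dilation  # pad on the left only, so outputs look backwards


def conv(precision):
    return jax.lax.conv_general_dilated(
        x, w,
        window_strides=(1,),
        padding=[(pad, 0)],
        rhs_dilation=(dilation,),  # dilate the kernel, as kernel_dilation does
        precision=precision,
    )


out_default = conv(jax.lax.Precision.DEFAULT)
out_highest = conv(jax.lax.Precision.HIGHEST)  # request full-precision passes
```

On CPU both settings give the same result. On accelerators, DEFAULT may use faster, lower-precision arithmetic for float32 inputs.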

use_bias: bool = True

whether to add a bias to the output (default True).

variables

Returns the variables in this module.

Return type

Mapping[str, Mapping[str, Any]]

Methods
activation(x)

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Parameters

x (Any) – input array

Return type

Any
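The formula above translates directly into jax.numpy. The sketch below uses a hypothetical function name, selu_manual, and checks it against jax.nn.selu, which implements the same function. Whether this model's default activation is that exact JAX function is not stated beyond the name selu in the signature.

```python
import jax
import jax.numpy as jnp

LAMBDA = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717


def selu_manual(x):
    # lambda * x for x > 0; lambda * (alpha * e^x - alpha) for x <= 0
    return LAMBDA * jnp.where(x > 0, x, ALPHA * jnp.exp(x) - ALPHA)


x = jnp.linspace(-3.0, 3.0, 13)
print(jnp.max(jnp.abs(selu_manual(x) - jax.nn.selu(x))))
```

For large negative x the output approaches -lambda * alpha (about -1.758). This bounded negative saturation is part of what makes the activation self-normalizing.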

bias_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)
conditionals(inputs)

Computes the conditional probabilities for each site to take each value, given the values of the preceding sites.

Parameters

inputs (Union[ndarray, DeviceArray, Tracer]) – configurations with dimensions (batch, Hilbert.size).

Return type

Union[ndarray, DeviceArray, Tracer]

Returns

The probabilities with dimensions (batch, Hilbert.size, Hilbert.local_size).

Examples

>>> p = model.apply(variables, σ, method=model.conditionals)
>>> print(p[2, 3, :])
[0.3 0.7]
# For the 4th spin (site index 3) of the 3rd sample (batch index 2),
# the probability of being spin down (local state index 0) is 0.3,
# and the probability of being spin up (local state index 1) is 0.7.
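An autoregressive model factorises the probability of a full configuration as p(σ) = p(σ_1) p(σ_2 | σ_1) ... p(σ_N | σ_1, ..., σ_{N-1}). The log-probability of each sample is therefore the sum of the logs of the conditionals selected by its local states. The sketch below shows that bookkeeping on an array shaped like the return value of conditionals, (batch, Hilbert.size, Hilbert.local_size). The helper log_prob_from_conditionals and the toy numbers are hypothetical and do not call NetKet. The site values are assumed to be given as local state indices, for example 0 for spin down and 1 for spin up as in the example above.

```python
import jax.numpy as jnp


def log_prob_from_conditionals(p, idx):
    """Hypothetical helper.

    p:   (batch, N, local_size) conditional probabilities.
    idx: (batch, N) integer local state index of each site.
    Returns (batch,) log-probabilities of the configurations.
    """
    # Select p[b, i, idx[b, i]] for every sample b and site i.
    chosen = jnp.take_along_axis(p, idx[..., None], axis=-1)[..., 0]
    # The product of conditionals becomes a sum of logs.
    return jnp.sum(jnp.log(chosen), axis=-1)


# One sample, two sites: site 0 is spin up, site 1 is spin down.
p = jnp.array([[[0.3, 0.7], [0.6, 0.4]]])
idx = jnp.array([[1, 0]])
print(log_prob_from_conditionals(p, idx))  # log(0.7 * 0.6) = log(0.42)
```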
kernel_init(shape, dtype=<class 'jax._src.numpy.lax_numpy.float32'>)
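The initializers in jax.nn.initializers are called with a PRNG key as their first argument, followed by the shape and dtype. The sketch below calls a zeros initializer and a variance_scaling initializer that way. The default kernel_init is a variance_scaling initializer whose scale, mode, and distribution are not shown on this page. The values used here (1.0, "fan_in", "normal") are assumptions for illustration only. The kernel shape (kernel_size, in_features, out_features) follows the layout used by Flax's Conv, also an assumption here.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

bias_init = jax.nn.initializers.zeros
# Assumed configuration; the model's actual default parameters may differ.
kernel_init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "normal")

# Bias for 8 output features.
b = bias_init(key, (8,), jnp.float32)
# Kernel for kernel_size=3, 1 input feature, 8 output features.
w = kernel_init(key, (3, 1, 8), jnp.float32)
```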