netket.sampler.MetropolisSampler

class netket.sampler.MetropolisSampler(*args, __skip_preprocess=False, **kwargs)

Bases: netket.sampler.base.Sampler

Metropolis-Hastings sampler for a Hilbert space, driven by a specific transition rule.

The transition rule is used to generate a proposed state \(s^\prime\), starting from the current state \(s\). The move is accepted with probability

\[A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} F(e^{L(s,s^\prime)}) \right) ,\]

where the probability being sampled from is \(P(s)=|M(s)|^p\). Here \(M(s)\) is a user-provided function (the machine), \(p\) is also user-provided with default value \(p=2\), and \(L(s,s^\prime)\) is a suitable correcting factor computed by the transition kernel.
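
As a rough illustration of the acceptance rule above, the sketch below computes \(A(s \rightarrow s^\prime)\) from log-amplitudes \(\log M(s)\) with NumPy. It assumes a symmetric transition rule, for which the correction term drops out and the rule reduces to \(\min(1, P(s^\prime)/P(s))\). The function name acceptance_probability is hypothetical and not part of NetKet.

```python
import numpy as np


def acceptance_probability(log_m_old, log_m_new, machine_pow=2):
    """min(1, P(s')/P(s)) with P(s) = |M(s)|**machine_pow, from log-amplitudes.

    Assumption: a symmetric transition rule, so the kernel correction is trivial.
    log_m_* may be complex: |M| = exp(Re[log M]), so the phase drops out.
    """
    log_ratio = machine_pow * (np.real(log_m_new) - np.real(log_m_old))
    # exp(min(0, x)) equals min(1, exp(x)) but never overflows for large x
    return np.exp(np.minimum(0.0, log_ratio))


# |M(s)| = 1 and |M(s')| = 0.5 (with an arbitrary phase): ratio 0.5**2 = 0.25
a = acceptance_probability(0.0, np.log(0.5) + 1j * np.pi / 3)
```

Working in log space is the natural choice here because machines are usually evaluated as \(\log M(s)\), and the ratio of two large amplitudes could otherwise overflow.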

The dtype of the sampled states can be chosen.

__init__(*args, __skip_preprocess=False, **kwargs)

Constructs a Metropolis Sampler.

Parameters
  • hilbert – The Hilbert space to sample.

  • rule – A MetropolisRule to generate random transitions from a given state as well as uniform random states.

  • n_sweeps – The number of exchanges that compose a single sweep. If None, n_sweeps is set to the number of degrees of freedom being sampled (the size of the input vector s to the machine).

  • reset_chains – If False, the state configuration is not reset when reset() is called.

  • n_chains – The number of Markov chains to be run in parallel on a single process.

  • machine_pow – The power to which the machine should be exponentiated to generate the pdf (default = 2).

  • dtype – The dtype of the states sampled (default = np.float32).
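
The sketch below shows, under simplifying assumptions, how n_chains, n_sweeps, machine_pow and dtype could enter a single Metropolis step: all chains advance in parallel, and one step applies n_sweeps single-spin-flip proposals. The toy machine log_psi, the helper metropolis_sweep and the spin-flip rule are hypothetical stand-ins written with plain NumPy, not NetKet's implementation.

```python
import numpy as np


def log_psi(sigma):
    # Hypothetical toy "machine": log M(s) = 0.5 * sum_i s_i * s_{i+1} on a periodic chain
    return 0.5 * np.sum(sigma * np.roll(sigma, 1, axis=-1), axis=-1)


def metropolis_sweep(sigma, rng, n_sweeps=None, machine_pow=2):
    """Apply n_sweeps single-spin-flip Metropolis proposals to every chain.

    sigma has shape (n_chains, n_sites) with entries +1/-1; all chains advance
    in parallel. The input array is not modified.
    """
    n_chains, n_sites = sigma.shape
    if n_sweeps is None:
        n_sweeps = n_sites  # default: one proposal per degree of freedom
    rows = np.arange(n_chains)
    for _ in range(n_sweeps):
        # Symmetric rule (assumption): flip one uniformly chosen site in each chain
        proposal = sigma.copy()
        sites = rng.integers(0, n_sites, size=n_chains)
        proposal[rows, sites] *= -1
        # log[P(s')/P(s)] with P(s) = |M(s)|**machine_pow
        log_ratio = machine_pow * (np.real(log_psi(proposal)) - np.real(log_psi(sigma)))
        # Accept each chain's move with probability min(1, P(s')/P(s))
        accept = np.log(rng.random(n_chains)) < log_ratio
        sigma = np.where(accept[:, None], proposal, sigma)
    return sigma


rng = np.random.default_rng(0)
n_chains, n_sites = 16, 10
# The dtype of the sampled states is a free choice; float32 is used here
sigma0 = rng.choice([-1.0, 1.0], size=(n_chains, n_sites)).astype(np.float32)
sigma1 = metropolis_sweep(sigma0, rng)  # n_sweeps=None -> n_sites proposals
```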

Attributes
is_exact

Returns True if the sampler is exact.

The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them.

Return type

bool
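
For comparison, an exact sampler can be sketched for very small systems by enumerating every configuration, normalising \(|M(s)|^p\) and drawing independent samples from it. This only illustrates what "exact" means here, assuming spin configurations with entries ±1; the helpers exact_distribution and exact_samples and the toy machine are hypothetical.

```python
import itertools

import numpy as np


def exact_distribution(log_psi, n_sites, machine_pow=2):
    # Enumerate all 2**n_sites spin configurations: only feasible for small systems
    states = np.array(list(itertools.product([-1.0, 1.0], repeat=n_sites)))
    log_p = machine_pow * np.real(log_psi(states))
    p = np.exp(log_p - log_p.max())  # shift by the maximum for numerical stability
    return states, p / p.sum()       # normalised P(s) = |M(s)|**p / Z


def exact_samples(log_psi, n_sites, n_samples, rng, machine_pow=2):
    states, p = exact_distribution(log_psi, n_sites, machine_pow)
    # Independent draws: no correlation between samples, unlike a Markov chain
    idx = rng.choice(len(states), size=n_samples, p=p)
    return states[idx]


# Hypothetical toy machine: a uniform "field" term, log M(s) = 0.3 * sum_i s_i
def field(s):
    return 0.3 * np.sum(s, axis=-1)


samples = exact_samples(field, n_sites=4, n_samples=1000, rng=np.random.default_rng(0))
```

A Metropolis sampler instead produces correlated samples along each chain, which is why it would not be expected to report is_exact as True.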

machine_pow: int = 2

Exponent of the pdf sampled.

n_batches

The batch size of the configuration \(\sigma\) used by this sampler.

In general, it is equivalent to n_chains.

Return type

int

n_chains

The total number of chains across all MPI ranks.

If you are not using MPI, this is equal to n_chains_per_rank.

Return type

int

n_chains_per_rank: int = None

Number of independent chains on every MPI rank.
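
The relation between n_chains and n_chains_per_rank is plain arithmetic. The first helper below encodes it; the second is a hypothetical way to choose a per-rank count that reaches a target total, and is not claimed to be how NetKet splits chains internally.

```python
import math


def total_chains(n_chains_per_rank: int, n_ranks: int = 1) -> int:
    # n_chains counts the chains on all MPI ranks; without MPI there is one rank
    return n_chains_per_rank * n_ranks


def per_rank_for_target(n_chains_target: int, n_ranks: int) -> int:
    # Hypothetical helper: smallest per-rank count whose total reaches the target
    return max(math.ceil(n_chains_target / n_ranks), 1)


# Asking for about 100 chains on 8 ranks gives 13 per rank, i.e. 104 in total
per_rank = per_rank_for_target(100, 8)
```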

n_sweeps: int = None

Number of sweeps for each step along the chain. Defaults to the number of sites in the Hilbert space.

reset_chains: bool = False

If True, resets the chain state when reset() is called (that is, before every new sampling).

rule: netket.sampler.metropolis.MetropolisRule = None

The metropolis transition rule.

Methods
init_state(machine, parameters, seed=None)

Creates the structure holding the state of the sampler.

If you want reproducible samples, you should specify seed, otherwise the state will be initialised randomly.

If running across several MPI processes, all sampler_states are guaranteed to be in a different (but deterministic) state. This is achieved by first reducing (summing) the seed provided to every MPI rank, then generating n_rank seeds starting from the reduced one, and every rank is initialized with one of those seeds.
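
The seeding procedure described above can be simulated in a single process as below. The list seed_on_each_rank is a hypothetical stand-in for the seed each MPI rank provides, the Python sum replaces an MPI reduction, and NumPy's default_rng is used only as a stand-in for whatever generator the library actually uses.

```python
import numpy as np


def distribute_seeds(seed_on_each_rank):
    """Simulate how per-rank sampler seeds can be derived from per-rank inputs.

    seed_on_each_rank[r] is the (hypothetical) seed passed on MPI rank r.
    """
    n_ranks = len(seed_on_each_rank)
    # 1. Reduce: sum the seeds provided on every rank (an allreduce under MPI)
    reduced = int(sum(seed_on_each_rank))
    # 2. Deterministically generate n_ranks seeds starting from the reduced one
    rng = np.random.default_rng(reduced)
    derived = rng.integers(0, 2**31 - 1, size=n_ranks)
    # 3. Rank r initialises its sampler state with derived[r]
    return [int(s) for s in derived]


seeds = distribute_seeds([7, 7, 7])
```

Because only the sum enters, [7, 7, 7] and [21, 0, 0] yield the same per-rank seeds in this sketch, while each rank still receives its own seed.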

The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.

Parameters
  • machine (Union[Callable, Module]) – a Flax module or callable with the forward pass of the log-pdf.

  • parameters (Any) – The PyTree of parameters of the model.

  • seed (Union[int, Any, None]) – An optional seed or jax PRNGKey. If not specified, a random seed will be used.

Return type

SamplerState

Returns

The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.

log_pdf(model)

Returns a closure with the log_pdf function encoded by this sampler.

Note: the result is returned as a HashablePartial so that the closure does not trigger recompilation.

Parameters

model (Union[Callable, Module]) – The machine, or apply_fun

Return type

Callable

Returns

the log probability density function

replace(**updates)

Returns a new object replacing the specified fields with new values.
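
replace follows the usual frozen-dataclass pattern: the original object is left untouched and a modified copy is returned. The sketch below shows that pattern with the standard library on a hypothetical ToySamplerConfig; with an actual sampler one would call sampler.replace(...) as documented here.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySamplerConfig:
    # Hypothetical stand-in for some of the sampler's immutable fields
    n_chains_per_rank: int = 16
    n_sweeps: int = 10
    machine_pow: int = 2
    reset_chains: bool = False


cfg = ToySamplerConfig()
# A new object with machine_pow changed; every other field is copied over
cfg2 = dataclasses.replace(cfg, machine_pow=1)
```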

reset(machine, parameters, state=None)

Resets the state of the sampler. To be used every time the parameters are changed.

Parameters
  • machine (Union[Callable, Module]) – a Flax module or callable with the forward pass of the log-pdf.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If it’s not provided, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.

Return type

SamplerState

Returns

A valid sampler state.
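
What reset() does to the chain configurations, depending on reset_chains, can be sketched as below. The helper reset_configurations is hypothetical and works on ±1 spin configurations only; the real sampler state may hold more than the configurations.

```python
import numpy as np


def reset_configurations(sigma, rng, n_chains, n_sites, reset_chains=False,
                         dtype=np.float32):
    """Hypothetical sketch of reset() semantics for the chain configurations.

    sigma is None for a freshly initialised state; otherwise it holds the last
    configuration of every chain, with shape (n_chains, n_sites).
    """
    if sigma is None or reset_chains:
        # Start every chain from a uniformly random configuration
        return rng.choice([-1.0, 1.0], size=(n_chains, n_sites)).astype(dtype)
    # reset_chains=False: chains resume from their last configurations
    return sigma


rng = np.random.default_rng(0)
s0 = reset_configurations(None, rng, n_chains=4, n_sites=3)
s1 = reset_configurations(s0, rng, 4, 3)                     # kept as-is
s2 = reset_configurations(s0, rng, 4, 3, reset_chains=True)  # fresh draw
```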

sample(machine, parameters, *, state=None, chain_length=1)

Samples chain_length elements along the chains.

Parameters
  • sampler – The Monte Carlo sampler.

  • machine (Union[Callable, Module]) – The model or callable to sample from (if it’s a function it should have the signature f(parameters, σ) -> jnp.ndarray).

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – current state of the sampler. If None, then initialises it.

  • chain_length (int) – (default=1), the length of the chains.

Returns

A tuple (σ, state): σ is the next batch of samples and state is the new state of the sampler.

Return type

tuple
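
Conceptually, chain_length relates sample to sample_next as in the loop below: the state is threaded through chain_length steps and the batches are stacked. This is a sketch under that assumption, not NetKet's implementation; flip_all is a deliberately trivial placeholder step (not a Metropolis update) so that the output can be checked by hand.

```python
import numpy as np


def sample(sample_next, state, chain_length=1):
    """Collect chain_length successive batches by repeatedly calling sample_next.

    sample_next(state) -> (new_state, sigma), sigma of shape (n_chains, n_sites).
    Returns (samples, state), samples of shape (chain_length, n_chains, n_sites).
    """
    batches = []
    for _ in range(chain_length):
        state, sigma = sample_next(state)  # thread the state through every step
        batches.append(sigma)
    return np.stack(batches), state


# Hypothetical placeholder step (NOT a Metropolis step): flip every spin of every chain
def flip_all(state):
    new = -state
    return new, new


sigma0 = np.ones((3, 4))
samples, final_state = sample(flip_all, sigma0, chain_length=5)
```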

sample_next(machine, parameters, state=None)

Samples the next state in the Markov chain.

Parameters
  • machine (Union[Callable, Module]) – a Flax module or callable apply function with the forward pass of the log-pdf.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If it’s not provided, it will be constructed by calling sampler.reset(machine, parameters) with a random seed.

Returns

A tuple (state, σ): state is the new state of the sampler and σ is the next batch of samples.

Return type

tuple