netket.sampler.MetropolisSamplerNumpy
class netket.sampler.MetropolisSamplerNumpy(hilbert, rule, *, n_sweeps=None, reset_chain=False, **kwargs)
Bases: netket.sampler.metropolis.MetropolisSampler
Metropolis-Hastings sampler for a Hilbert space according to a specific transition rule, executed on CPU through NumPy.
This sampler is equivalent to netket.sampler.MetropolisSampler. The difference is in what runs inside a jax-jitted function: MetropolisSampler executes the whole sampling there, while this sampler only evaluates the forward pass of the model in a jax-jitted function. Proposing new steps and accepting or rejecting them is done in NumPy.
Because of the Jax dispatch cost, this sampler performs poorly, especially for small systems. Asymptotically it should have the same performance as the standard Jax samplers.
However, some transition rules don’t work on GPU, and some samplers (Hamiltonian) work very poorly in Jax, so this sampler is a good workaround in those cases.
See netket.sampler.MetropolisSampler for more information.
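The following minimal NumPy sketch shows one Metropolis-Hastings step for spin-1/2 chains. It illustrates the split between model evaluation and NumPy bookkeeping described above, under assumptions, and is not NetKet's implementation. The single-spin-flip proposal, the toy `log_psi` model and the helper `metropolis_step` are hypothetical. In MetropolisSamplerNumpy, only the model evaluation would go through a jax-jitted function. The proposal and the accept/reject test correspond to the part done in NumPy.

```python
import numpy as np


def log_psi(sigma):
    # Hypothetical toy log-amplitude favouring aligned neighbouring spins.
    # In MetropolisSamplerNumpy the model's forward pass would run jax-jitted.
    return 0.5 * np.sum(sigma[:, :-1] * sigma[:, 1:], axis=1)


def metropolis_step(sigma, rng, machine_pow=2):
    """One Metropolis-Hastings step applied to every chain (row) at once."""
    n_chains, n_sites = sigma.shape
    # Propose: flip one randomly chosen spin per chain (a symmetric proposal).
    proposal = sigma.copy()
    sites = rng.integers(0, n_sites, size=n_chains)
    proposal[np.arange(n_chains), sites] *= -1
    # Accept with probability min(1, |psi(proposal)/psi(sigma)|**machine_pow),
    # computed in log space from batched model evaluations.
    log_ratio = machine_pow * np.real(log_psi(proposal) - log_psi(sigma))
    accept = np.log(rng.uniform(size=n_chains)) < log_ratio
    # Accepted chains move to the proposal; rejected chains keep their state.
    new_sigma = np.where(accept[:, None], proposal, sigma)
    return new_sigma, accept


rng = np.random.default_rng(0)
sigma = rng.choice([-1.0, 1.0], size=(8, 6))  # 8 chains, 6 spins each
sigma, accepted = metropolis_step(sigma, rng)
print(sigma.shape, accepted.mean())
```

The model is called on whole batches of configurations rather than chain by chain. This keeps the number of calls into the jitted forward pass independent of the number of chains.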
__init__(hilbert, rule, *, n_sweeps=None, reset_chain=False, **kwargs)
Constructs a Metropolis Sampler.
Parameters
hilbert (AbstractHilbert) – The Hilbert space to sample.
rule (MetropolisRule) – A MetropolisRule to generate random transitions from a given state as well as uniform random states.
n_sweeps (Optional[int]) – The number of exchanges that compose a single sweep. If None, the sweep size is equal to the number of degrees of freedom being sampled (the size of the input vector s to the machine).
reset_chain (bool) – If False, the state configuration is not reset when reset() is called.
n_chains – The number of Markov chains (batches of states) to be run in parallel on a single process (default = 8).
machine_pow – The power to which the machine should be exponentiated to generate the pdf (default = 2).
dtype – The dtype of the sampled states (default = np.float32).
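Two of these parameters shape the sampling: `machine_pow` fixes the target distribution, and `n_sweeps` fixes how many elementary steps separate two samples. The sketch below enumerates a tiny system exactly, assuming spin-1/2 degrees of freedom and a hypothetical toy `log_psi`. It shows the distribution p(σ) ∝ |ψ(σ)|^machine_pow that the chains sample from, and how `n_sweeps=None` could be resolved to the number of degrees of freedom. The helpers `exact_pdf` and `resolve_n_sweeps` are illustrative names, not NetKet functions.

```python
import itertools

import numpy as np


def log_psi(sigma):
    # Hypothetical complex log-amplitude; only its real part affects |psi|.
    return 0.4 * np.sum(sigma[:, :-1] * sigma[:, 1:], axis=1) + 0.1j * np.sum(sigma, axis=1)


def exact_pdf(n_sites, machine_pow=2):
    """Normalised target p(σ) ∝ |psi(σ)|**machine_pow over all 2**n_sites spin states."""
    configs = np.array(list(itertools.product([-1.0, 1.0], repeat=n_sites)))
    # |psi|**machine_pow == exp(machine_pow * Re(log psi)); shift by the max for stability.
    log_p = machine_pow * np.real(log_psi(configs))
    p = np.exp(log_p - log_p.max())
    return configs, p / p.sum()


def resolve_n_sweeps(n_sweeps, n_sites):
    # Documented default: None means the number of degrees of freedom.
    return n_sites if n_sweeps is None else n_sweeps


configs, p = exact_pdf(4, machine_pow=2)
print(configs.shape, p.sum(), resolve_n_sweeps(None, 4))
```

With machine_pow = 2, the target is the Born distribution |ψ(σ)|², normalised over all configurations.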
Attributes
n_batches
The batch size of the configuration σ used by this sampler. In general, it is equivalent to n_chains.
Return type
int
n_sweeps: int = 0
Number of sweeps for each step along the chain. Defaults to the number of sites in the Hilbert space.
reset_chain: bool = False
If True, resets the chain state when reset() is called (that is, at every new sampling).
rule: netket.sampler.metropolis.MetropolisRule = None
The Metropolis transition rule.
Methods
init_state(machine, parameters, seed=None)
Creates the structure holding the state of the sampler.
If you want reproducible samples, you should specify seed; otherwise the state will be initialised randomly.
If running across several MPI processes, all sampler states are guaranteed to be in a different (but deterministic) state. This is achieved in three steps. First, the seeds provided to every MPI rank are reduced (summed). Then n_rank seeds are generated starting from the reduced one. Finally, every rank is initialized with one of those seeds.
The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.
Return type
SamplerState
Returns
The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.
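The MPI seeding scheme described for init_state can be simulated in a single process. In this hypothetical sketch, the MPI reduction is replaced by a plain Python `sum`. Deriving the per-rank seeds with `np.random.default_rng(...).choice` is also an assumption made for illustration; NetKet's own seed generation may differ.

```python
import numpy as np


def per_rank_seeds(seeds_from_each_rank):
    """Derive one distinct, deterministic seed for each (simulated) MPI rank."""
    n_rank = len(seeds_from_each_rank)
    # Stand-in for an MPI reduction (sum) of the seed given on every rank.
    reduced = sum(seeds_from_each_rank)
    # Generate n_rank seeds from the reduced one; sampling without
    # replacement makes them pairwise different.
    rng = np.random.default_rng(reduced)
    return rng.choice(2**31 - 1, size=n_rank, replace=False)


seeds = per_rank_seeds([1234, 1234, 1234, 1234])  # same user seed on 4 ranks
rank = 2
print(seeds[rank])  # the seed rank 2 would use to initialise its sampler state
```

Because the derived seeds depend only on the reduced value, rerunning with the same per-rank seeds reproduces the same states on every rank.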
log_pdf(model)
Returns a closure with the log_pdf function encoded by this sampler.
Note: the result is returned as a HashablePartial so that the closure does not trigger recompilation.
replace(**updates)
Returns a new object replacing the specified fields with new values.
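`replace` follows the usual pattern for immutable dataclasses: nothing is modified in place, and a new object with some fields swapped is returned. The hypothetical `ToySampler` below mirrors two of the documented attributes and uses the standard library's `dataclasses.replace`. The sampler state documented above is a Flax dataclass, so this is only an analogy.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySampler:
    # Hypothetical stand-in mirroring two documented attributes.
    n_sweeps: int = 0
    reset_chain: bool = False

    def replace(self, **updates):
        """Return a new object with the given fields replaced."""
        return dataclasses.replace(self, **updates)


s1 = ToySampler(n_sweeps=10)
s2 = s1.replace(reset_chain=True)
print(s1.reset_chain, s2.reset_chain, s2.n_sweeps)  # False True 10
```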
reset(machine, parameters, state=None)
Resets the state of the sampler. To be used every time the parameters are changed.
Parameters
machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf.
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If it’s not provided, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.
Return type
SamplerState
Returns
A valid sampler state.
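The interplay between `reset` and the `reset_chain` attribute can be sketched as follows, under assumptions. The state is a frozen dataclass holding one spin-1/2 configuration per chain. A reset either re-draws uniformly random configurations (when reset_chain is True) or keeps the chains where they stopped. `ToyState`, `random_state` and this `reset` function are hypothetical and do not reproduce NetKet's signatures.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyState:
    # Hypothetical sampler state holding the current configuration of each chain.
    sigma: np.ndarray


def random_state(rng, n_chains, n_sites):
    # Uniformly random spin-1/2 configurations, one row per chain.
    return rng.choice([-1.0, 1.0], size=(n_chains, n_sites))


def reset(state, rng, reset_chain=False):
    """Return a new state; configurations are re-drawn only if reset_chain is True."""
    if reset_chain:
        sigma = random_state(rng, *state.sigma.shape)
    else:
        sigma = state.sigma  # chains continue from their last configuration
    return dataclasses.replace(state, sigma=sigma)


rng = np.random.default_rng(0)
state = ToyState(sigma=random_state(rng, 4, 6))
kept = reset(state, rng, reset_chain=False)
fresh = reset(state, rng, reset_chain=True)
print(np.array_equal(kept.sigma, state.sigma))  # True
```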
sample(machine, parameters, *, state=None, chain_length=1)
Samples chain_length elements along the chains.
Parameters
sampler – The Monte Carlo sampler.
machine (Union[Callable, Module]) – The model or callable to sample from (if it’s a function it should have the signature f(parameters, σ) -> jnp.ndarray).
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If None, it is initialised.
chain_length (int) – The length of the chains (default = 1).
Returns
state – The new state of the sampler.
σ – The next batch of samples.
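One way to picture `sample` is as chain_length consecutive single-step updates, collecting the batch produced at each step. This is a hypothetical sketch under three assumptions. First, the state is just the array of current configurations. Second, `sample_next` is a toy update that flips one random spin per chain without any accept/reject test. Third, the output layout (chain_length, n_chains, n_sites) is an illustrative choice rather than NetKet's documented shape.

```python
import numpy as np


def sample_next(sigma, rng):
    """Hypothetical single-step update: flip one random spin per chain.

    A real Metropolis sampler would also apply the accept/reject test here.
    """
    sigma = sigma.copy()
    sites = rng.integers(0, sigma.shape[1], size=sigma.shape[0])
    sigma[np.arange(sigma.shape[0]), sites] *= -1
    return sigma, sigma  # (new state, next batch of samples)


def sample(sigma, rng, chain_length=1):
    """Collect chain_length successive batches from every chain."""
    batches = []
    for _ in range(chain_length):
        sigma, batch = sample_next(sigma, rng)
        batches.append(batch)
    # Shape (chain_length, n_chains, n_sites): an illustrative layout.
    return sigma, np.stack(batches)


rng = np.random.default_rng(0)
sigma0 = rng.choice([-1.0, 1.0], size=(4, 6))
state, samples = sample(sigma0, rng, chain_length=5)
print(samples.shape)  # (5, 4, 6)
```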
sample_next(machine, parameters, state=None)
Samples the next state in the Markov chain.
Parameters
machine (Union[Callable, Module]) – A Flax module or callable apply function with the forward pass of the log-pdf.
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If it’s not provided, it will be constructed by calling sampler.reset(machine, parameters) with a random seed.
Returns
state – The new state of the sampler.
σ – The next batch of samples.