Using JAX as a backend in NetKet - Feature Preview for v3.0
In this tutorial we will show how differentiable functions (for example deep networks) written in JAX can be used as variational quantum states in NetKet.
This feature will be available in the upcoming major release (version 3.0). While version 3.0 is still in beta development, users can already try this feature.
To try out the integration with JAX, you first need to install the beta version of NetKet v3. We recommend using a virtual environment (either a Python venv or a conda environment), for example:
python3 -m venv nk_env
source nk_env/bin/activate
pip install --pre -U netket
Defining the quantum system
NetKet offers full flexibility in defining quantum systems, for example when tackling a ground-state search problem. Although a few Hamiltonians are predefined, it is relatively straightforward to implement new quantum operators and Hamiltonians.
In the following, we consider the case of a transverse-field Ising model defined on a graph with random edges.
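Written out, and with the same (positive) signs used in the code that follows, the Hamiltonian is given below. Here \(\sigma^x_i\) and \(\sigma^z_i\) are Pauli matrices acting on site \(i\), \(E\) is the set of random edges, \(n = 10\) and \(J = 0.5\):

```latex
H = \sum_{i=0}^{n-1} \sigma^x_i \;+\; J \sum_{(i,j) \in E} \sigma^z_i \, \sigma^z_j
```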
# Ensure we run on the CPU and not on the GPU
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import netket as nk
# Define a random graph
n_nodes = 10
n_edges = 20
from numpy.random import choice
rand_edges = [choice(n_nodes, size=2, replace=False).tolist() for i in range(n_edges)]
graph = nk.graph.Graph(nodes=[i for i in range(n_nodes)], edges=rand_edges)
# Define the local Hilbert space
hi = nk.hilbert.Spin(s=0.5) ** graph.n_nodes
# Define the Hamiltonian as a sum of local operators
from netket.operator import LocalOperator as Op
# Pauli matrices
sx = [[0, 1], [1, 0]]
sz = [[1, 0], [0, -1]]
# Define the Hamiltonian as a LocalOperator acting on the given Hilbert space
ha = Op(hi)
# Add a transverse-field term on each node of the graph
for i in range(graph.n_nodes):
    ha += Op(hi, sx, [i])
# Add nearest-neighbor interactions on the edges of the given graph
from numpy import kron
J = 0.5
for edge in graph.edges():
    ha += J * Op(hi, kron(sz, sz), edge)
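The list comprehension above draws each pair independently, so the same edge (possibly in reversed order) can be drawn more than once, and a different graph is produced on every run. If distinct edges and a reproducible graph are preferred, a minimal sketch using NumPy's Generator API follows. The helper random_edges is hypothetical, not part of NetKet; it returns the same list-of-pairs form that is passed to nk.graph.Graph above.

```python
import numpy as np

def random_edges(n_nodes, n_edges, seed=0):
    """Hypothetical helper: draw n_edges distinct undirected edges [i, j] with i < j."""
    rng = np.random.default_rng(seed)
    # All candidate pairs i < j of a graph with n_nodes vertices
    pairs = [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]
    if n_edges > len(pairs):
        raise ValueError("more edges requested than distinct pairs exist")
    # Sampling indices without replacement guarantees that no pair appears twice
    idx = rng.choice(len(pairs), size=n_edges, replace=False)
    return [list(pairs[k]) for k in sorted(idx)]

edges = random_edges(10, 20, seed=1234)
print(len(edges), edges[:3])
```

Fixing the seed makes the graph, and therefore the exact ground-state energy, identical between runs.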
Defining a JAX module to be used as a wave function
We now define a JAX model to be used as the variational wave-function ansatz. To simplify the discussion, we consider a simple single-layer, fully connected network with complex weights and a \(\tanh\) activation function. Such networks are easy to define in JAX, for example with the stax library. The only requirement is that the network takes as input JAX arrays of shape (batch_size, n), where batch_size is an arbitrary batch size and n is the number of quantum degrees of freedom (for example, the number of spins in the previous example), and returns one value per configuration, which NetKet interprets as the logarithm of the wave-function amplitude. Regardless of the dimensionality of the problem (for instance, a two-dimensional lattice), the degrees of freedom are always flattened into the last axis.
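The shape contract can be illustrated without NetKet. The following is a minimal NumPy sketch. The function log_psi is a hypothetical stand-in for the stax model defined next: a dense layer with complex weights, a tanh activation, and a sum over hidden units. The weight initialization is our own choice, not NetKet's initializer. Because jax.numpy provides the same functions, replacing np with jnp should give an equivalent JAX version.

```python
import numpy as np

def log_psi(W, b, sigma):
    """Hypothetical single-layer ansatz: log psi(sigma) = sum_k tanh((sigma @ W + b)_k).

    sigma has shape (batch_size, n); the result has shape (batch_size,),
    one complex log-amplitude per configuration.
    """
    return np.tanh(sigma @ W + b).sum(axis=-1)

n, alpha, batch_size = 10, 2, 4
rng = np.random.default_rng(0)
# Small complex weights and biases (illustrative initialization)
W = 0.1 * (rng.normal(size=(n, alpha * n)) + 1j * rng.normal(size=(n, alpha * n)))
b = 0.1 * (rng.normal(size=alpha * n) + 1j * rng.normal(size=alpha * n))

# A batch of spin-1/2 configurations with entries in {-1, +1}
sigma = rng.choice([-1.0, 1.0], size=(batch_size, n))
print(log_psi(W, b, sigma).shape)  # (4,)

# A configuration on a 2x5 lattice must first be flattened into the last axis
lattice_batch = sigma.reshape(batch_size, 2, 5)
flat = lattice_batch.reshape(batch_size, -1)
```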
import jax
# stax lives in jax.experimental in older JAX releases and in jax.example_libraries in newer ones
try:
    from jax.example_libraries import stax
except ImportError:
    from jax.experimental import stax
# We define a custom layer that sums its inputs over the last axis
def SumLayer():
    def init_fun(rng, input_shape):
        # Summing over the last axis removes it: (batch, m) -> (batch,)
        output_shape = input_shape[:-1]
        return output_shape, ()
    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)
    return init_fun, apply_fun
# We construct a fully connected network with tanh activation
model = stax.serial(
    stax.Dense(2 * graph.n_nodes,
               W_init=nk.nn.initializers.normal(stddev=0.1, dtype=complex),
               b_init=nk.nn.initializers.normal(stddev=0.1, dtype=complex)),
    stax.Tanh,
    SumLayer(),
)
# Alternatively, we could have used flax, which would have been easier:
# class Model(nk.nn.Module):
#     @nk.nn.compact
#     def __call__(self, x):
#         x = nk.nn.Dense(features=2 * x.shape[-1], dtype=complex,
#                         kernel_init=nk.nn.initializers.normal(stddev=0.01),
#                         bias_init=nk.nn.initializers.normal(stddev=0.01))(x)
#         x = jax.numpy.tanh(x)
#         return jax.numpy.sum(x, axis=-1)
# model = Model()
# Alternatively (2), we could have used the built-in RBM model:
# model = nk.models.RBM(alpha=2, use_visible_bias=False, dtype=complex)
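stax layers are (init_fun, apply_fun) pairs. init_fun(rng, input_shape) returns (output_shape, params), and apply_fun(params, inputs) computes the layer's output. stax.serial threads each layer's output_shape into the next layer's init_fun, so the shape a layer declares should match what its apply_fun returns. For SumLayer, which sums over the last axis, that shape is input_shape[:-1]. The NumPy stand-ins below (Dense, Tanh, SumLayer, serial) are hypothetical re-implementations of this convention for illustration only. The real stax splits a jax.random key between layers, whereas here all layers share one NumPy Generator.

```python
import numpy as np

# Hypothetical NumPy stand-ins following the stax (init_fun, apply_fun) convention
def Dense(out_dim):
    def init_fun(rng, input_shape):
        n_in = input_shape[-1]
        W = 0.1 * (rng.normal(size=(n_in, out_dim)) + 1j * rng.normal(size=(n_in, out_dim)))
        b = 0.1 * (rng.normal(size=out_dim) + 1j * rng.normal(size=out_dim))
        return input_shape[:-1] + (out_dim,), (W, b)
    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return inputs @ W + b
    return init_fun, apply_fun

# Like stax.Tanh, this is a ready-made pair rather than a layer constructor
Tanh = (lambda rng, input_shape: (input_shape, ()),
        lambda params, inputs, **kwargs: np.tanh(inputs))

def SumLayer():
    def init_fun(rng, input_shape):
        return input_shape[:-1], ()  # the summed axis disappears from the shape
    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)
    return init_fun, apply_fun

def serial(*layers):
    init_funs, apply_funs = zip(*layers)
    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            input_shape, p = init(rng, input_shape)  # output shape feeds the next layer
            params.append(p)
        return input_shape, params
    def apply_fun(params, inputs, **kwargs):
        for apply, p in zip(apply_funs, params):
            inputs = apply(p, inputs, **kwargs)
        return inputs
    return init_fun, apply_fun

init_fun, apply_fun = serial(Dense(20), Tanh, SumLayer())
out_shape, params = init_fun(np.random.default_rng(0), (-1, 10))
x = np.ones((3, 10))
print(out_shape, apply_fun(params, x).shape)  # (-1,) (3,)
```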
Train the neural network to find an approximate ground state
To perform Variational Monte Carlo (VMC), we also need to specify a sampler (to estimate expectation values over the variational state) and an optimizer. In the following we use stochastic gradient descent (SGD) preconditioned with the quantum natural gradient; in the VMC literature this scheme is known as Stochastic Reconfiguration (SR).
# Define a sampler that performs local moves
# NetKet automatically dispatches here to an MCMC sampler written using JAX types
sa = nk.sampler.MetropolisLocal(hilbert=hi, n_chains=2)
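A local Metropolis sampler proposes flipping a single spin and accepts the move with probability \(\min(1, |\psi(s')|^2/|\psi(s)|^2)\). Several independent chains run in parallel, and their number is set by n_chains. The function metropolis_local below is a hypothetical NumPy sketch of this textbook algorithm. NetKet's JAX implementation (sweep length, chain bookkeeping, burn-in) may differ.

```python
import numpy as np

def metropolis_local(log_psi, n, n_chains, n_sweeps, seed=0):
    """Hypothetical single-spin-flip Metropolis sampler for |psi(s)|^2, spins in {-1, +1}.

    log_psi maps a (batch, n) array to (batch,) log-amplitudes.
    Returns samples of shape (n_sweeps, n_chains, n).
    """
    rng = np.random.default_rng(seed)
    sigma = rng.choice([-1.0, 1.0], size=(n_chains, n))
    logp = 2.0 * np.real(log_psi(sigma))  # log |psi|^2 for every chain
    samples = np.empty((n_sweeps, n_chains, n))
    for s in range(n_sweeps):
        for _ in range(n):  # one sweep = n proposed local moves per chain
            # Propose flipping one randomly chosen spin in every chain
            site = rng.integers(n, size=n_chains)
            proposal = sigma.copy()
            proposal[np.arange(n_chains), site] *= -1
            logp_new = 2.0 * np.real(log_psi(proposal))
            # Metropolis acceptance test, done in log space
            accept = np.log(rng.random(n_chains)) < logp_new - logp
            sigma[accept] = proposal[accept]
            logp[accept] = logp_new[accept]
        samples[s] = sigma
    return samples

# Toy product state that favors spins pointing up
samples = metropolis_local(lambda s: 0.3 * s.sum(axis=-1), n=10, n_chains=2, n_sweeps=200)
print(samples.shape)  # (200, 2, 10)
```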
# Construct the variational state
vs = nk.variational.MCState(sa, model, n_samples=1000)
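A Monte Carlo variational state estimates the energy from its samples (here n_samples=1000) through the local energy \(E_\mathrm{loc}(s) = \sum_{s'} H_{ss'}\,\psi(s')/\psi(s)\). Averaged over \(|\psi(s)|^2\), the local energy gives \(\langle\psi|H|\psi\rangle/\langle\psi|\psi\rangle\). The toy NumPy check below uses a random dense Hermitian matrix and a random complex vector, both hypothetical stand-ins rather than NetKet objects. Because \(E_\mathrm{loc}\) is complex, a finite-sample average generally has a small imaginary part, which explains terms such as -0.02j in the energies printed during training.

```python
import numpy as np

rng = np.random.default_rng(1)
dim = 8  # toy Hilbert-space dimension (e.g. 3 spins)

# Random Hermitian "Hamiltonian" and random complex wave function
A = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
H = (A + A.conj().T) / 2
psi = rng.normal(size=dim) + 1j * rng.normal(size=dim)

# Local energy E_loc(s) = sum_s' H[s, s'] psi(s') / psi(s)
e_loc = (H @ psi) / psi
# Born probabilities p(s) = |psi(s)|^2 / <psi|psi>
p = np.abs(psi) ** 2 / np.vdot(psi, psi).real

# The exact average of E_loc over p equals the variational energy
energy_from_eloc = np.sum(p * e_loc)
energy_exact = np.vdot(psi, H @ psi) / np.vdot(psi, psi)
print(np.allclose(energy_from_eloc, energy_exact))  # True

# Monte Carlo estimate: average E_loc over configurations drawn from p
draws = rng.choice(dim, size=1000, p=p)
energy_mc = e_loc[draws].mean()  # complex, with a small imaginary part
```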
# Use SGD
# (this also dispatches to a JAX optimizer)
op = nk.optimizer.Sgd(learning_rate=0.01)
# Use Stochastic Reconfiguration, a.k.a. the quantum natural gradient
# (this also dispatches to a pure JAX version)
sr = nk.optimizer.SR(diag_shift=0.01)
# Create the Variational Monte Carlo instance to learn the ground state
vmc = nk.VMC(
    hamiltonian=ha,
    optimizer=op,
    variational_state=vs,
    preconditioner=sr,
)
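In the standard formulation of Stochastic Reconfiguration, the log-derivatives are \(O_p(s) = \partial_{\theta_p} \log\psi(s)\). The update solves \((S + \epsilon I)\,\delta = F\), where \(S_{pq} = \langle O_p^* O_q\rangle - \langle O_p^*\rangle\langle O_q\rangle\) and \(F_p = \langle O_p^* E_\mathrm{loc}\rangle - \langle O_p^*\rangle\langle E_\mathrm{loc}\rangle\). Plain SGD then applies \(\theta \leftarrow \theta - \eta\,\delta\). Here \(\epsilon\) plays the role of diag_shift and \(\eta\) the role of learning_rate. The function sr_update below is a hypothetical textbook sketch on synthetic data. NetKet's SR may differ in detail, for example in its linear solver or its handling of real versus complex parameters.

```python
import numpy as np

def sr_update(O, e_loc, diag_shift=0.01, learning_rate=0.01):
    """Hypothetical SR + SGD step.

    O[k, p]  : d log psi(s_k) / d theta_p for sample k
    e_loc[k] : local energy of sample k
    Returns the parameter change -learning_rate * (S + diag_shift * I)^{-1} F.
    """
    n_samples, n_params = O.shape
    dO = O - O.mean(axis=0)              # centered log-derivatives
    dE = e_loc - e_loc.mean()            # centered local energies
    S = dO.conj().T @ dO / n_samples     # quantum geometric tensor (covariance of O)
    F = dO.conj().T @ dE / n_samples     # energy gradient ("forces")
    delta = np.linalg.solve(S + diag_shift * np.eye(n_params), F)
    return -learning_rate * delta

rng = np.random.default_rng(0)
O = rng.normal(size=(1000, 5)) + 1j * rng.normal(size=(1000, 5))
e_loc = rng.normal(size=1000) + 0j
print(sr_update(O, e_loc).shape)  # (5,)
```

The diagonal shift keeps the linear system well conditioned when \(S\) is nearly singular, which is common when parameters are redundant. A constant local energy, as for an exact eigenstate, gives \(F = 0\) and therefore no update.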
Running the training loop
The latest version of NetKet also allows finer control of the VMC loop. In the simplest case, one can iterate over the generator returned by vmc.iter and print the current value of the energy. More sophisticated output schemes based on TensorBoard have also been implemented, but are not discussed in this tutorial.
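The loop that follows calls vmc.iter(500, 50), and the indices it prints (0, 50, up to 450) are consistent with a generator that runs 500 optimization steps and hands control back at the start of every block of 50. The sketch below illustrates this pattern. iter_steps and fake_step are hypothetical names, and this is not NetKet's code.

```python
def iter_steps(step_fn, n_steps, step=1):
    """Hypothetical mimic: run step_fn n_steps times, yielding the index every `step` steps."""
    for start in range(0, n_steps, step):
        yield start  # the caller can log here, e.g. print the current energy
        for _ in range(min(step, n_steps - start)):
            step_fn()  # one optimization step

counter = {"steps": 0}

def fake_step():
    counter["steps"] += 1

logged = [i for i in iter_steps(fake_step, 500, 50)]
print(logged)            # [0, 50, 100, 150, 200, 250, 300, 350, 400, 450]
print(counter["steps"])  # 500
```

Because of the generator design, the last block of steps still runs after the final yield, provided the caller keeps iterating.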
# Run the learning loop and print the energy every 50 steps
# [notice that the very first iteration is slow because of JIT compilation]
for it in vmc.iter(500, 50):
    print(it, vmc.energy)
0 6.20-0.02j ± 0.12 [σ²=13.25, R̂=0.9994]
50 -5.98-0.06j ± 0.14 [σ²=9.84, R̂=0.9992]
100 -10.04-0.08j ± 0.11 [σ²=10.24, R̂=0.9995]
150 -10.907+0.014j ± 0.042 [σ²=1.230, R̂=1.0001]
200 -11.261-0.014j ± 0.034 [σ²=0.855, R̂=0.9993]
250 -11.396-0.013j ± 0.024 [σ²=1.402, R̂=0.9993]
300 -11.532+0.015j ± 0.015 [σ²=0.198, R̂=1.0000]
350 -11.727-0.002j ± 0.019 [σ²=0.196, R̂=1.0028]
400 -11.830-0.013j ± 0.017 [σ²=0.105, R̂=1.0051]
450 -11.872+0.011j ± 0.010 [σ²=0.055, R̂=0.9991]
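Each output line reports the energy estimate with its statistical error, the variance σ² of the local energy, and the convergence diagnostic R̂. The variance shrinks as the state approaches an eigenstate, because the local energy of an exact eigenstate is constant. R̂ compares the spread between Markov chains with the spread within each chain, and values close to 1 indicate that the chains agree. The function gelman_rubin below is a hypothetical sketch of the textbook Gelman-Rubin statistic. NetKet's estimator may differ (for example, in how chains are split), so treat the sketch as illustrative.

```python
import numpy as np

def gelman_rubin(chains):
    """Textbook Gelman-Rubin R-hat for an array of shape (n_chains, n_samples)."""
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()        # mean within-chain variance
    B_over_n = chains.mean(axis=1).var(ddof=1)   # variance of the chain means (B / n)
    var_hat = (n - 1) / n * W + B_over_n         # pooled estimate of the target variance
    return np.sqrt(var_hat / W)

x = np.sin(np.arange(100))
same = np.stack([x, x])              # two chains that agree perfectly
shifted = np.stack([x, x + 5.0])     # two chains stuck around different values
print(gelman_rubin(same), gelman_rubin(shifted))
```

In this textbook form, two identical chains give \(\sqrt{(n-1)/n}\), slightly below 1. This suggests why values slightly below 1, such as 0.9994 above, need not signal a problem.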
Comparing to exact diagonalization
Since this is a relatively small quantum system (the Hilbert space has 2^10 = 1024 states), we can still diagonalize the Hamiltonian exactly. For this purpose, NetKet conveniently exposes a .to_sparse method that converts the Hamiltonian into a scipy sparse matrix. Here we first obtain this sparse matrix and then compute its lowest eigenvalue with scipy's built-in sparse eigensolver.
import scipy.sparse.linalg
# Lowest ("smallest algebraic", 'SA') eigenvalue of the sparse Hamiltonian
exact_ens = scipy.sparse.linalg.eigsh(ha.to_sparse(), k=1, which='SA', return_eigenvectors=False)
print("Exact energy is : ", exact_ens[0])
print("Relative error is : ", abs((vmc.energy.mean - exact_ens[0]) / exact_ens[0]))
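The combination of .to_sparse and eigsh can be reproduced on a smaller instance of the same model. The sketch below builds the Hamiltonian directly as a SciPy sparse matrix, using the same sign convention, and checks the result of eigsh (ARPACK's Lanczos method) against dense diagonalization. The helpers site_op and tfim_sparse are hypothetical and do not reflect NetKet's implementation. The sparse matrix holds at most (n + 1)·2^n nonzero entries, whereas a dense matrix needs 4^n, which is why the sparse route scales to larger systems.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

sx = sp.csr_matrix([[0.0, 1.0], [1.0, 0.0]])
sz = sp.csr_matrix([[1.0, 0.0], [0.0, -1.0]])

def site_op(op, i, n):
    """Embed a single-site operator acting on site i into the 2**n-dimensional space."""
    out = sp.identity(1, format="csr")
    for k in range(n):
        out = sp.kron(out, op if k == i else sp.identity(2, format="csr"), format="csr")
    return out

def tfim_sparse(n, edges, J=0.5):
    """H = sum_i sx_i + J * sum_(i,j) sz_i sz_j, with the same signs as the tutorial."""
    H = sp.csr_matrix((2**n, 2**n))
    for i in range(n):
        H = H + site_op(sx, i, n)
    for i, j in edges:
        H = H + J * (site_op(sz, i, n) @ site_op(sz, j, n))
    return H

n = 6
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (0, 3)]
H = tfim_sparse(n, edges)

# Smallest algebraic eigenvalue via Lanczos, as in the tutorial
e0_sparse = spla.eigsh(H, k=1, which="SA", return_eigenvectors=False)[0]
# Dense cross-check, affordable only for small n
e0_dense = np.linalg.eigvalsh(H.toarray())[0]
print(e0_sparse, e0_dense)
```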
Exact energy is :  -11.932889012463688
Relative error is :  0.0034391959338140226