Using JAX as a backend in NetKet - Feature Preview for v3.0

In this tutorial we will show how differentiable functions (for example deep networks) written in JAX can be used as variational quantum states in NetKet.

This feature will be available in the upcoming major release (version 3.0). While version 3.0 is still in beta development, users can already try this feature.


To try out the integration with JAX, you first need to install the beta version of NetKet v3. We recommend using a virtual environment (either a Python virtual environment or a conda environment), for example:

python3 -m venv nk_env
source nk_env/bin/activate
pip install --pre -U netket

Defining the quantum system

NetKet allows full flexibility in defining quantum systems, for example when tackling a ground-state search problem. While a few Hamiltonians are predefined, it is relatively straightforward to implement new quantum operators and Hamiltonians.

In the following, we consider the case of a transverse-field Ising model defined on a graph with random edges.

\[H = -\sum_{i\in\textrm{nodes}} \sigma^x_{i} + J \sum_{(i,j)\in\textrm{edges}}\sigma_{i}^{z}\sigma_{j}^{z}\]
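For a handful of spins, the Hamiltonian above can be written out as a dense matrix with Kronecker products, which is a useful sanity check of the notation. The NumPy sketch below uses a hypothetical 4-site graph and J = 0.5 (illustrative values, not those used later in this tutorial); site_op and tfim_dense are hypothetical helpers, not NetKet functions. It also checks two facts: the x-polarized product state bounds the ground-state energy from above by -n, and flipping the sign of the transverse-field term leaves the spectrum unchanged, because the unitary ∏_i σ^z_i maps σ^x_i to -σ^x_i while leaving σ^z_i σ^z_j invariant.

```python
from functools import reduce

import numpy as np

SX = np.array([[0.0, 1.0], [1.0, 0.0]])
SZ = np.array([[1.0, 0.0], [0.0, -1.0]])
I2 = np.eye(2)


def site_op(op, i, n):
    """Embed a single-site operator at site i of n spins: I x ... x op x ... x I."""
    return reduce(np.kron, [op if k == i else I2 for k in range(n)])


def tfim_dense(n, edges, J, field_sign=-1.0):
    """Dense H = field_sign * sum_i X_i + J * sum_(i,j) Z_i Z_j (hypothetical helper)."""
    H = np.zeros((2**n, 2**n))
    for i in range(n):
        H += field_sign * site_op(SX, i, n)
    for i, j in edges:
        H += J * site_op(SZ, i, n) @ site_op(SZ, j, n)
    return H


n = 4
edges = [(0, 1), (1, 2), (2, 3), (0, 2)]  # hypothetical small graph
H_minus = tfim_dense(n, edges, J=0.5, field_sign=-1.0)  # sign convention of the equation
H_plus = tfim_dense(n, edges, J=0.5, field_sign=+1.0)   # opposite transverse-field sign
e_minus = np.linalg.eigvalsh(H_minus)
e_plus = np.linalg.eigvalsh(H_plus)
print("ground-state energy:", e_minus[0])
```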
# ensure we run on the CPU and not on the GPU
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import netket as nk

#Define a random graph
# number of nodes and of random edges (illustrative values)
n_nodes = 10
n_edges = 20
from numpy.random import choice
rand_edges=[choice(n_nodes, size=2,replace=False).tolist() for i in range(n_edges)]

graph=nk.graph.Graph(nodes=[i for i in range(n_nodes)], edges=rand_edges)

#Define the local Hilbert space: one spin-1/2 per node of the graph
hi = nk.hilbert.Spin(s=0.5, N=graph.n_nodes)

#Define the Hamiltonian as a sum of local operators
from netket.operator import LocalOperator as Op

# Pauli Matrices
sx = [[0, 1], [1, 0]]
sz = [[1, 0], [0, -1]]

# Defining the Hamiltonian as a LocalOperator acting on the given Hilbert space
ha = Op(hi)

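#Adding a transverse field term on each node of the graph
# (this adds +σ^x; the -σ^x convention of H above has the same spectrum,
#  since the unitary ∏_i σ^z_i maps σ^x -> -σ^x and leaves σ^z σ^z unchanged)
for i in range(graph.n_nodes):
    ha += Op(hi, sx, [i])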

#Adding nearest-neighbor interactions on the edges of the given graph
from numpy import kron
J = 0.5  # coupling strength (illustrative value)
for edge in graph.edges():
    ha += J*Op(hi, kron(sz, sz), edge)
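The list comprehension that builds rand_edges draws each pair independently, so the same edge can appear more than once (also in reversed order, e.g. [2, 5] and [5, 2]). Depending on how the graph stores such duplicates, the coupling on that pair may be counted more than once, or the graph may end up with fewer than n_edges distinct edges. If distinct edges are preferred, one option is to sample without replacement from all node pairs, as in the NumPy sketch below; random_distinct_edges and its seed argument are hypothetical, not NetKet functions.

```python
from itertools import combinations

import numpy as np


def random_distinct_edges(n_nodes, n_edges, seed=0):
    """Sample n_edges distinct undirected edges [i, j] with i < j (hypothetical helper)."""
    all_pairs = list(combinations(range(n_nodes), 2))
    if n_edges > len(all_pairs):
        raise ValueError("more edges requested than there are node pairs")
    rng = np.random.default_rng(seed)
    # choosing indices without replacement guarantees that no pair repeats
    idx = rng.choice(len(all_pairs), size=n_edges, replace=False)
    return [list(all_pairs[k]) for k in idx]


edges = random_distinct_edges(10, 20)
print(edges[:3])
```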

Defining a JAX module to be used as a wave function

We now want to define a suitable JAX function to be used as a wave-function ansatz. To simplify the discussion, we consider here a simple single-layer fully connected network with complex weights and a \(\tanh\) activation function. Such networks are easy to define in JAX, for example with a model built with STAX. The only requirement is that these networks take as input JAX arrays of shape (batch_size, n), where batch_size is an arbitrary batch size and n is the number of quantum degrees of freedom (for example, the number of spins in the previous example), and return one value per configuration, namely the logarithm of the wave-function amplitude. Notice that regardless of the dimensionality of the problem, the last dimension is always flattened into a single index.

import jax
# note: in recent JAX releases this module lives at jax.example_libraries.stax
from jax.experimental import stax

#We define a custom layer that sums its inputs over the last axis
def SumLayer():
    def init_fun(rng, input_shape):
        # summing over the last axis removes it: (batch_size, m) -> (batch_size,)
        output_shape = input_shape[:-1]
        return output_shape, ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun

#We construct a fully connected network with tanh activation, followed by the sum layer
model=stax.serial(stax.Dense(2 * graph.n_nodes, W_init=nk.nn.initializers.normal(stddev=0.1, dtype=complex),
                             b_init=nk.nn.initializers.normal(stddev=0.1, dtype=complex)),
                  stax.Tanh,
                  SumLayer())

# Alternatively, we could have used flax, which would have been easier:
#class Model(nk.nn.Module):
#    @nk.nn.compact
#    def __call__(self, x):
#        x = nk.nn.Dense(features=2*x.shape[-1], dtype=complex, kernel_init=nk.nn.initializers.normal(stddev=0.01), bias_init=nk.nn.initializers.normal(stddev=0.01))(x)
#        x = jax.numpy.tanh(x)
#        return jax.numpy.sum(x, axis=-1)
#model = Model()

# As a second alternative, we could have used the built-in RBM model:
#model = nk.models.RBM(alpha=2, use_visible_bias=False, dtype=complex)
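The same architecture can also be written by hand in plain JAX, which makes the shape contract described above explicit. In the sketch below, a function for a single configuration of shape (n,) is turned by jax.vmap into one that accepts (batch_size, n) and returns (batch_size,). It assumes, as above, that the output is read as log ψ; the names init_params and log_psi_single, and the initialization scale, are illustrative and not part of NetKet or STAX.

```python
import jax
import jax.numpy as jnp


def init_params(key, n, alpha=2):
    """Draw complex weights W of shape (alpha*n, n) and biases b of shape (alpha*n,).

    Real and imaginary parts are independent normals with stddev 0.1 (illustrative).
    """
    k1, k2, k3, k4 = jax.random.split(key, 4)
    m = alpha * n
    W = 0.1 * (jax.random.normal(k1, (m, n)) + 1j * jax.random.normal(k2, (m, n)))
    b = 0.1 * (jax.random.normal(k3, (m,)) + 1j * jax.random.normal(k4, (m,)))
    return W, b


def log_psi_single(params, sigma):
    """Log-amplitude of one configuration sigma of shape (n,): sum_k tanh(W sigma + b)_k."""
    W, b = params
    return jnp.sum(jnp.tanh(W @ sigma + b))


# vmap maps over the leading (batch) axis of sigma only, so the batched
# function takes (batch_size, n) and returns (batch_size,)
log_psi = jax.jit(jax.vmap(log_psi_single, in_axes=(None, 0)))

n = 10
params = init_params(jax.random.PRNGKey(0), n)
sigma = jax.random.choice(jax.random.PRNGKey(1), jnp.array([-1.0, 1.0]), shape=(32, n))
out = log_psi(params, sigma)
print(out.shape, out.dtype)
```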

Train the neural network to find an approximate ground state

In order to perform Variational Monte Carlo, we also need to specify a suitable sampler (to estimate expectation values over the variational state) as well as an optimizer. In the following we adopt Stochastic Gradient Descent coupled with quantum natural gradients (a scheme known in the VMC literature as Stochastic Reconfiguration).
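The sampler used here, nk.sampler.MetropolisLocal, is named after the textbook single-spin-flip Metropolis rule: propose flipping one randomly chosen spin and accept the move with probability min(1, |ψ(σ')/ψ(σ)|²). The NumPy sketch below illustrates that rule only; it is not NetKet's implementation, and metropolis_local, its seed argument and the absence of burn-in or thinning are simplifying assumptions.

```python
import numpy as np


def metropolis_local(log_psi, n, n_samples, n_chains=2, seed=0):
    """Single-spin-flip Metropolis sampling of |psi(sigma)|^2 for sigma in {-1, +1}^n.

    log_psi maps an (n_chains, n) array to complex log-amplitudes of shape (n_chains,).
    No burn-in or thinning: the chain states after every proposal are recorded.
    """
    rng = np.random.default_rng(seed)
    sigma = rng.choice([-1.0, 1.0], size=(n_chains, n))
    samples, n_accepted = [], 0
    for _ in range(n_samples // n_chains):
        proposal = sigma.copy()
        site = rng.integers(0, n, size=n_chains)       # one random site per chain
        proposal[np.arange(n_chains), site] *= -1.0     # flip that spin
        # accept with probability min(1, |psi(proposal) / psi(sigma)|^2), in log space
        log_ratio = 2.0 * np.real(log_psi(proposal) - log_psi(sigma))
        accept = np.log(rng.random(n_chains)) < log_ratio
        sigma[accept] = proposal[accept]
        n_accepted += int(accept.sum())
        samples.append(sigma.copy())
    return np.concatenate(samples, axis=0), n_accepted / (len(samples) * n_chains)


def uniform_log_psi(s):
    """A constant wave function: every proposed move is accepted."""
    return np.zeros(s.shape[0], dtype=complex)


samples, acceptance = metropolis_local(uniform_log_psi, n=6, n_samples=100)
print(samples.shape, acceptance)
```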

# Defining a sampler that performs local moves
# NetKet automatically dispatches here to an MCMC sampler written using JAX types
sa = nk.sampler.MetropolisLocal(hilbert=hi, n_chains=2)
# Construct the variational state
vs = nk.variational.MCState(sa, model, n_samples=1000)
# Using Sgd
# This also dispatches to a JAX optimizer
op = nk.optimizer.Sgd(learning_rate=0.01)

# Using Stochastic Reconfiguration a.k.a. quantum natural gradient
# Also dispatching to a pure JAX version
sr = nk.optimizer.SR(diag_shift=0.01)

# Create the Variational Monte Carlo instance to learn the ground state
vmc = nk.VMC(
    hamiltonian=ha, optimizer=op, variational_state=vs, preconditioner=sr
)
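Stochastic Reconfiguration replaces the plain energy gradient F by the solution δθ of (S + diag_shift·I) δθ = F, where S is the covariance matrix of the log-derivatives O_k = ∂ log ψ / ∂θ_k over the Monte Carlo samples. The diag_shift argument passed to nk.optimizer.SR above is, as its name suggests, a shift of this kind added to the diagonal of S to keep the linear system well conditioned. The NumPy sketch below shows the textbook form of one such update followed by a plain SGD step; it is not NetKet's implementation (which may differ in the linear solver or in how complex parameters are handled), and sr_update is a hypothetical helper.

```python
import numpy as np


def sr_update(O, E_loc, diag_shift=0.01, learning_rate=0.01):
    """One Stochastic Reconfiguration step from Monte Carlo data (hypothetical helper).

    O      : (n_samples, n_params) log-derivatives d log psi / d theta_k per sample
    E_loc  : (n_samples,) local energies
    Returns (delta, S_reg, F) with delta = -learning_rate * (S + diag_shift * I)^-1 F.
    """
    n_samples = O.shape[0]
    O_c = O - O.mean(axis=0)                  # centred log-derivatives
    E_c = E_loc - E_loc.mean()                # centred local energies
    S = O_c.conj().T @ O_c / n_samples        # S_kl = <O_k^* O_l> - <O_k^*><O_l>
    F = O_c.conj().T @ E_c / n_samples        # F_k = <O_k^* E_loc> - <O_k^*><E_loc>
    S_reg = S + diag_shift * np.eye(S.shape[0])
    delta = -learning_rate * np.linalg.solve(S_reg, F)
    return delta, S_reg, F


# synthetic data, only to exercise the function
rng = np.random.default_rng(0)
O = rng.normal(size=(500, 5)) + 1j * rng.normal(size=(500, 5))
E_loc = rng.normal(size=500) + 0.1j * rng.normal(size=500)
delta, S_reg, F = sr_update(O, E_loc)
print(delta)
```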

Running the training loop

The latest version of NetKet also allows finer control of the VMC loop. In the simplest case, one can just iterate through the vmc object and print the current value of the energy. More sophisticated output schemes based on TensorBoard have also been implemented, but are not discussed in this tutorial.
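Iterating with vmc.iter(n_steps, step) hands control back to the loop body every step optimization steps, which is why the logged output lists iterations 0, 50, ..., 450 for vmc.iter(500, 50). The pure-Python toy below mimics only that iteration pattern; ToyDriver is hypothetical and its step method is a placeholder, not NetKet's driver.

```python
class ToyDriver:
    """Hypothetical stand-in for a VMC driver, used only to show the iteration pattern."""

    def __init__(self):
        self.step_count = 0

    def step(self):
        # placeholder for: sample, estimate energy and gradient, update parameters
        self.step_count += 1

    def iter(self, n_steps, step=1):
        """Run n_steps optimization steps, yielding the step counter every `step` steps."""
        for _ in range(0, n_steps, step):
            yield self.step_count
            for _ in range(step):
                self.step()


driver = ToyDriver()
logged = [it for it in driver.iter(500, 50)]
print(logged)
```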

# Running the learning loop and printing the energy every 50 steps
# [notice that the very first iteration is slow because of JIT compilation]
for it in vmc.iter(500,50):
    print(it,vmc.energy)
0 6.20-0.02j ± 0.12 [σ²=13.25, R̂=0.9994]
50 -5.98-0.06j ± 0.14 [σ²=9.84, R̂=0.9992]
100 -10.04-0.08j ± 0.11 [σ²=10.24, R̂=0.9995]
150 -10.907+0.014j ± 0.042 [σ²=1.230, R̂=1.0001]
200 -11.261-0.014j ± 0.034 [σ²=0.855, R̂=0.9993]
250 -11.396-0.013j ± 0.024 [σ²=1.402, R̂=0.9993]
300 -11.532+0.015j ± 0.015 [σ²=0.198, R̂=1.0000]
350 -11.727-0.002j ± 0.019 [σ²=0.196, R̂=1.0028]
400 -11.830-0.013j ± 0.017 [σ²=0.105, R̂=1.0051]
450 -11.872+0.011j ± 0.010 [σ²=0.055, R̂=0.9991]
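Each logged line reports the energy estimate with its error bar, the variance σ² of the local energies, and R̂, a convergence diagnostic that compares the Monte Carlo chains (values close to 1 indicate that the chains agree). The variance vanishes for an exact eigenstate, which is why σ² shrinks as training proceeds. The NumPy sketch below computes standard textbook versions of these quantities; mc_stats is a hypothetical helper, and NetKet's own estimators (in particular the error bar, which may account for autocorrelation, and its exact R̂ formula) may differ.

```python
import numpy as np


def mc_stats(E_loc):
    """Summary statistics for local energies of shape (n_chains, n_steps).

    Returns (mean, naive error of the mean, variance, Gelman-Rubin R-hat).
    """
    _, n = E_loc.shape
    mean = E_loc.mean()
    variance = E_loc.var()                        # sigma^2 over all samples
    error = np.sqrt(variance / E_loc.size)        # ignores autocorrelation
    W = E_loc.var(axis=1, ddof=1).mean()          # mean within-chain variance
    B = n * E_loc.mean(axis=1).var(ddof=1)        # between-chain variance
    var_hat = (n - 1) / n * W + B / n             # pooled variance estimate
    r_hat = np.sqrt(var_hat / W)
    return mean, error, variance, r_hat


rng = np.random.default_rng(0)
mixed = rng.normal(loc=-11.9, scale=0.3, size=(2, 5000))  # two chains, same distribution
stuck = mixed + np.array([[0.0], [1.5]])                    # second chain shifted
print(mc_stats(mixed)[3], mc_stats(stuck)[3])
```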

Comparing to exact diagonalization

Since this is a relatively small quantum system, we can still diagonalize the Hamiltonian exactly. For this purpose, NetKet conveniently exposes a .to_sparse method that converts the Hamiltonian into a SciPy sparse matrix. Here we first obtain this sparse matrix and then find its lowest eigenvalue with SciPy's sparse eigensolver.

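import scipy.sparse.linalg
# lowest (smallest algebraic) eigenvalue of the sparse Hamiltonian
exact_ens = scipy.sparse.linalg.eigsh(ha.to_sparse(), k=1, which="SA", return_eigenvectors=False)
print("Exact energy is : ",exact_ens[0])
print("Relative error is : ", (abs((vmc.energy.mean-exact_ens[0])/exact_ens[0])))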
Exact energy is :  -11.932889012463688
Relative error is :  0.0034391959338140226
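The sparse route works well here because H is very sparse: in the σ^z basis each row has at most n + 1 nonzero entries (one diagonal entry plus one per single spin flip) out of 2^n. The relative error printed above is |E_VMC - E_exact| / |E_exact|. The self-contained SciPy sketch below, independent of NetKet, builds the same kind of Hamiltonian for a hypothetical 8-site graph with J = 0.5 (illustrative values) and checks eigsh with which="SA" against dense diagonalization; site_op and tfim_sparse are hypothetical helpers.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh

SX = sp.csr_matrix(np.array([[0.0, 1.0], [1.0, 0.0]]))
SZ = sp.csr_matrix(np.array([[1.0, 0.0], [0.0, -1.0]]))
I2 = sp.identity(2, format="csr")


def site_op(op, i, n):
    """Embed a single-site operator at site i of n spins as a sparse matrix."""
    out = sp.identity(1, format="csr")
    for k in range(n):
        out = sp.kron(out, op if k == i else I2, format="csr")
    return out


def tfim_sparse(n, edges, J):
    """Sparse H = -sum_i X_i + J * sum_(i,j) Z_i Z_j, the form of the equation above."""
    H = sp.csr_matrix((2**n, 2**n))
    for i in range(n):
        H = H - site_op(SX, i, n)
    for i, j in edges:
        H = H + J * (site_op(SZ, i, n) @ site_op(SZ, j, n))
    return H


n = 8
edges = [(k, (k + 1) % n) for k in range(n)] + [(0, 4), (2, 6)]  # hypothetical graph
H = tfim_sparse(n, edges, J=0.5)
e0_sparse = eigsh(H, k=1, which="SA", return_eigenvectors=False)[0]
e0_dense = np.linalg.eigvalsh(H.toarray())[0]
print(e0_sparse, e0_dense, H.nnz, 2**n * 2**n)
```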