netket.nn.MaskedConv1D

class netket.nn.MaskedConv1D(features, kernel_size, kernel_dilation, exclusive, feature_group_count=1, use_bias=True, dtype=jnp.float64, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: flax.linen.module.Module

1D convolution module with a mask, for autoregressive neural networks.

features

number of convolution filters.

Type

int

kernel_size

length of the convolutional kernel.

Type

int

kernel_dilation

dilation factor of the convolution kernel.

Type

int

exclusive

True if an output element does not depend on the input element at the same index.

Type

bool

feature_group_count

number of feature groups (default: 1). If greater than 1, the input features are divided into this many groups, and each group is convolved separately.

Type

int

use_bias

whether to add a bias to the output (default: True).

Type

bool

dtype

the dtype of the computation (default: float64).

Type

Any

precision

numerical precision of the computation, see jax.lax.Precision for details.

Type

Any

kernel_init

initializer for the convolutional kernel.

Type

Callable[[Any, Sequence[int], Any], Union[numpy.ndarray, jaxlib.xla_extension.DeviceArray, jax.core.Tracer]]

bias_init

initializer for the bias.

Type

Callable[[Any, Sequence[int], Any], Union[numpy.ndarray, jaxlib.xla_extension.DeviceArray, jax.core.Tracer]]
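A minimal sketch of the dependency structure described by kernel_size (K), kernel_dilation (d) and exclusive. It assumes a single input and output feature, omits the bias, treats inputs at negative positions as zero padding, and assumes that exclusive=True removes the tap at offset zero so that K - 1 taps remain. These are assumptions for illustration, not a statement of NetKet's implementation:

```latex
y_i =
\begin{cases}
\sum_{k=0}^{K-1} w_k \, x_{i - k d}, & \text{exclusive} = \text{False},\\[4pt]
\sum_{k=1}^{K-1} w_k \, x_{i - k d}, & \text{exclusive} = \text{True},
\end{cases}
\qquad x_j = 0 \ \text{for}\ j < 0.
```

In both cases y_i depends only on inputs at positions i - kd ≤ i, which is the autoregressive property; with exclusive = True the dependence is strictly on earlier positions.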

Attributes
feature_group_count: int = 1
name: str = None
parent: Optional[Union[Type[Module], Type[Scope], Type[_Sentinel]]] = <flax.linen.module._Sentinel object>
precision: Any = None
scope = None
use_bias: bool = True
variables

Returns the variables in this module.

Return type

Mapping[str, Mapping[str, Any]]

Methods
__call__(inputs)

Applies a masked convolution to the inputs. For a 1D convolution no explicit mask is needed: the autoregressive structure is obtained by padding the input appropriately.

Parameters

inputs (Union[ndarray, DeviceArray, Tracer]) – input data with dimensions (batch, length, features).

Return type

Union[ndarray, DeviceArray, Tracer]

Returns

The convolved data.
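The padding idea can be sketched in plain JAX. This is a minimal sketch under stated assumptions, not NetKet's implementation: masked_conv1d is a hypothetical helper, the kernel is assumed to be laid out as (taps, in_features, out_features), with exclusive=True it is assumed to have kernel_size - 1 taps, and bias and feature_group_count are left out.

```python
import jax
import jax.numpy as jnp


def masked_conv1d(x, kernel, dilation=1, exclusive=False):
    """Hypothetical causal 1D convolution using left padding only (sketch).

    x:      (batch, length, in_features)
    kernel: (taps, in_features, out_features); with exclusive=True this
            sketch assumes taps = kernel_size - 1.
    Returns an array of shape (batch, length, out_features).
    """
    taps = kernel.shape[0]
    if exclusive:
        # Drop the last `dilation` positions and pad one extra dilation step
        # on the left: output i then sees i - taps*dilation, ..., i - dilation,
        # so x[:, i] itself is never used.
        x = x[:, :-dilation, :]
        left = taps * dilation
    else:
        # Output i sees i - (taps - 1)*dilation, ..., i.
        left = (taps - 1) * dilation
    # Zero-pad the sequence axis on the left only, so no output sees the future.
    x = jnp.pad(x, ((0, 0), (left, 0), (0, 0)))
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1,),
        padding="VALID",
        rhs_dilation=(dilation,),
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
```

Perturbing x at one position and re-running the function changes only outputs at that position or later (strictly later when exclusive=True), which is a quick way to check causality.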

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set to call apply on a class method other than __call__. For instance, suppose a Transformer module has a method called encode; then the following calls apply on that method:

model = Transformer()
encoded = model.apply({'params': params}, x, method=Transformer.encode)

If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:

encoded = model.apply({'params': params}, x, method=model.encode)

Note that method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

def other_fn(instance, ...):
  instance.some_module_attr(...)
  ...

model.apply({'params': params}, x, method=other_fn)

Parameters
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • rngs (Optional[Dict[str, Any]]) – a dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.

  • method (Optional[Callable[…, Any]]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module.

  • mutable (Union[bool, str, Container[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Return type

Union[Any, Tuple[Any, FrozenDict[str, Mapping[str, Any]]]]

Returns

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.

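bias_init(key, shape, dtype=jnp.float32)

Default bias initializer (zeros). Although it is listed here as a method, it is the callable stored in the bias_init attribute and takes a PRNG key as its first argument.
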
bind(variables, *args, rngs=None, mutable=False)

Creates an interactive Module instance by binding variables and RNGs.

bind provides an “interactive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks, where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module, it becomes a stateful object. Note that idiomatic JAX is functional, so an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation; in all other cases we strongly encourage using apply() instead.

Example:

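import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(3)
    self.decoder = nn.Dense(5)

  def __call__(self, x):
    return self.decoder(self.encoder(x))

x = jnp.ones((16, 5))
ae = AutoEncoder()
variables = ae.init(jax.random.PRNGKey(0), x)
model = ae.bind(variables)
z = model.encoder(x)
x_reconstructed = model.decoder(z)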

Parameters
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • rngs (Optional[Dict[str, Any]]) – A dict of PRNGKeys to initialize the PRNG sequences.

  • mutable (Union[bool, str, Container[str], flax.core.scope.DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

Returns

A copy of this instance with bound variables and RNGs.

clone(*, parent=None, **updates)

Creates a clone of this Module, with optionally updated arguments.

Parameters
  • parent (Union[Scope, Module, None]) – The parent of the clone. The clone will have no parent if no explicit parent is specified.

  • **updates – Attribute updates.

Return type

Module

Returns

A clone of this Module with the updated attributes and parent.

get_variable(col, name, default=None)

Retrieves the value of a Variable.

Parameters
  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • default (Optional[~T]) – the default value to return if the variable does not exist in this scope.

Return type

~T

Returns

The value of the input variable, or the default value if the variable doesn’t exist in this scope.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

See flax.core.variables for more explanation on variables and collections.

Parameters
  • col (str) – The variable collection name.

  • name (str) – The name of the variable.

Return type

bool

Returns

True if the variable exists.

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)

Initializes a module method with variables and returns modified variables.

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

jit_init = jax.jit(SomeModule(...).init)
jit_init(rng, jnp.ones(input_shape, jnp.float32))

Parameters
  • rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections.

  • method (Optional[Callable[…, Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Container[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

Return type

FrozenDict[str, Mapping[str, Any]]

Returns

The initialized variable dict.

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)

Initializes a module method with variables and returns output and modified variables.

Parameters
  • rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections.

  • method (Optional[Callable[…, Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Container[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

Return type

Tuple[Any, FrozenDict[str, Mapping[str, Any]]]

Returns

(output, vars), where vars is a dict of the modified collections.

is_mutable_collection(col)

Returns True if the collection col is mutable.

Return type

bool

Parameters

col (str) – The variable collection name.

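kernel_init(key, shape, dtype=jnp.complex64)

Default kernel initializer (a variance-scaling initializer). Although it is listed here as a method, it is the callable stored in the kernel_init attribute and takes a PRNG key as its first argument.
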
make_rng(name)

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Parameters

name (str) – The RNG sequence name.

Return type

Any

Returns

The newly generated RNG key.

param(name, init_fn, *init_args)

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args:

mean = self.param('mean', lecun_normal(), (2, 2))

In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

Parameters
  • name (str) – The parameter name.

  • init_fn (Callable[…, ~T]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args – The arguments to pass to init_fn.

Return type

~T

Returns

The value of the initialized parameter.

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_with_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    class MyModule(nn.Module):
      def setup(self):
        submodule = Conv(...)
    
        # Accessing `submodule` attributes does not yet work here.
    
        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule
    
        # Accessing `submodule` attributes or methods is now safe; whichever
        # happens first causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

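import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    self.sow('intermediates', 'h', h)
    return nn.Dense(2)(h)

x = jnp.ones((1, 3))
variables = Foo().init(jax.random.PRNGKey(0), x)
y, state = Foo().apply(variables, x, mutable=['intermediates'])
print(state['intermediates'])  # {'h': (...,)}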

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    init_fn = lambda: 0
    reduce_fn = lambda a, b: a + b
    self.sow('intermediates', 'h', x,
             init_fn=init_fn, reduce_fn=reduce_fn)
    self.sow('intermediates', 'h', x * 2,
             init_fn=init_fn, reduce_fn=reduce_fn)
    return x
y, state = Foo().apply({}, 1, mutable=['intermediates'])
print(state['intermediates'])  # ==> {'h': 3}

Parameters
  • col (str) – The name of the variable collection.

  • name (str) – The name of the variable.

  • value (~T) – The value of the variable.

  • reduce_fn (Callable[[~K, ~T], ~K]) – The function used to combine the existing value with the new value; the default is to append the value to a tuple.

  • init_fn (Callable[[], ~K]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Return type

bool

Returns

True if the value has been stored successfully, False otherwise.

variable(col, name, init_fn, *init_args)

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Unlike param(), all arguments to init_fn must be passed explicitly:

key = self.make_rng('stats')
mean = self.variable('stats', 'mean', lecun_normal(), key, (2, 2))

In the example above, the function lecun_normal expects two arguments: key and shape, and both have to be passed on. The PRNG for stats has to be provided explicitly when calling init() and apply().

Parameters
  • col (str) – The variable collection name.

  • name (str) – The variable name.

  • init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module.

  • *init_args – The arguments to pass to init_fn.

Return type

Variable

Returns

A flax.core.variables.Variable that can be read or set via “.value” attribute. Throws an error if the variable exists already.