Source code for sax.nn.core

""" SAX Neural Network Core Utils """

from __future__ import annotations

from typing import Callable, Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from ..saxtypes import Array, ComplexArrayND
from .utils import denormalize, normalize

def preprocess(*params: ComplexArrayND) -> ComplexArrayND:
    """Build a feature tensor from the parameters and their pairwise ratios.

    Steps:
        1. All arguments are broadcast against each other to a common shape.
        2. The broadcast arguments are stacked along a new trailing axis.
        3. For every cyclic shift ``1 .. n-1`` of that trailing axis, both
           the ratio ``stacked / shifted`` and its inverse
           ``shifted / stacked`` are appended.
        4. Everything is concatenated along the trailing axis.

    For ``n`` arguments the trailing axis of the result therefore has
    ``n * (2 * n - 1)`` entries: the ``n`` raw values followed by
    ``2 * (n - 1)`` groups of ``n`` ratios.

    Args:
        *params: the arrays (or scalars) to combine; they must be
            broadcastable against each other.

    Returns:
        The concatenated tensor of raw values and relative values.
    """
    stacked = jnp.stack(jnp.broadcast_arrays(*params), -1)
    assert isinstance(stacked, jnp.ndarray)  # narrows the type for checkers

    pieces = [stacked]
    for shift in range(1, stacked.shape[-1]):
        # pair every parameter with the one `shift` positions away (cyclic)
        shifted = jnp.roll(stacked, shift=shift, axis=-1)
        pieces.extend((stacked / shifted, shifted / stacked))

    features = jnp.concatenate(pieces, -1)
    assert isinstance(features, jnp.ndarray)  # narrows the type for checkers
    return features
def dense(
    weights: Dict[str, Array],
    *params: ComplexArrayND,
    x_norm: Tuple[float, float] = (0.0, 1.0),
    y_norm: Tuple[float, float] = (0.0, 1.0),
    preprocess: Callable = preprocess,
    activation: Callable = jax.nn.leaky_relu,
) -> ComplexArrayND:
    """Evaluate a simple dense (fully connected) neural network.

    Args:
        weights: mapping holding the layer kernels under ``"w0"``, ``"w1"``,
            ... and, optionally, the matching biases under ``"b0"``,
            ``"b1"``, ... A layer without a bias entry uses ``0.0``.
        *params: the raw inputs; they are forwarded to ``preprocess``.
        x_norm: ``(mean, std)`` pair handed to ``normalize`` for the
            preprocessed input.
        y_norm: ``(mean, std)`` pair handed to ``denormalize`` for the
            network output.
        preprocess: callable turning ``*params`` into the input tensor.
        activation: callable applied after every layer, the last one
            included.

    Returns:
        The denormalized output of the final layer.
    """
    x_mean, x_std = x_norm
    y_mean, y_std = y_norm

    hidden = normalize(preprocess(*params), mean=x_mean, std=x_std)

    # The layer count is taken from the number of keys starting with "w".
    num_layers = sum(1 for name in weights if name.startswith("w"))
    for layer in range(num_layers):
        kernel = weights[f"w{layer}"]
        bias = weights.get(f"b{layer}", 0.0)
        hidden = activation(hidden @ kernel + bias)

    return denormalize(hidden, mean=y_mean, std=y_std)
def generate_dense_weights(
    key: Union[int, Array],
    sizes: Tuple[int, ...],
    input_names: Optional[Tuple[str, ...]] = None,
    output_names: Optional[Tuple[str, ...]] = None,
    preprocess=preprocess,
) -> Dict[str, ComplexArrayND]:
    """Generate randomly initialized weights for a dense neural network.

    Args:
        key: either an integer seed (turned into a PRNG key) or an existing
            JAX PRNG key, which is used as is.
        sizes: the layer widths.
        input_names: when given (and non-empty), an input layer is
            prepended. Its width is found by running ``preprocess`` on one
            dummy value per input name and reading the trailing dimension.
        output_names: when given (and non-empty), an output layer with one
            unit per name is appended.
        preprocess: the preprocessing callable used to probe the input
            width.

    Returns:
        A dictionary with a kernel ``"w{i}"`` of shape ``(m, n)`` and a bias
        ``"b{i}"`` of shape ``(n,)`` for every pair of consecutive layer
        widths ``(m, n)``.
    """
    random_key = jax.random.PRNGKey(key) if isinstance(key, int) else key

    layer_sizes = tuple(sizes)
    if input_names:
        probe = preprocess(*jnp.ones(len(input_names)))
        assert isinstance(probe, jnp.ndarray)  # narrows the type for checkers
        layer_sizes = (probe.shape[-1],) + layer_sizes
    if output_names:
        layer_sizes = layer_sizes + (len(output_names),)

    # Two subkeys per entry in layer_sizes; layer i uses subkeys 2i and 2i+1.
    subkeys = jax.random.split(random_key, 2 * len(layer_sizes))
    init = jax.nn.initializers.lecun_normal()

    weights = {}
    for layer, (fan_in, fan_out) in enumerate(
        zip(layer_sizes[:-1], layer_sizes[1:])
    ):
        weights[f"w{layer}"] = init(subkeys[2 * layer], (fan_in, fan_out))
        # the initializer is given a 2D shape; the bias is flattened after
        weights[f"b{layer}"] = init(subkeys[2 * layer + 1], (1, fan_out)).ravel()
    return weights