Source code for sax.nn.utils
""" Sax Neural Network Default Utilities """
from __future__ import annotations
from collections import namedtuple
from typing import Tuple
import jax.numpy as jnp
import pandas as pd
from ..saxtypes import ComplexArrayND
def cartesian_product(*arrays: ComplexArrayND) -> ComplexArrayND:
    """Build the cartesian product of any number of arrays.

    Args:
        *arrays: the arrays to combine. Each one is passed to ``jnp.ix_``,
            which expects one-dimensional inputs.

    Returns:
        A 2D array with one row per combination and one column per input
        array. Rows are ordered so that the last input varies fastest.
    """
    # open mesh -> full grids of identical shape, one grid per input
    grids = jnp.broadcast_arrays(*jnp.ix_(*arrays))
    # put the per-input coordinates next to each other on a new last axis
    stacked = jnp.stack(grids, -1)
    assert isinstance(stacked, jnp.ndarray)
    n_columns = stacked.shape[-1]
    # collapse all grid axes into a single "combination" axis
    flat = stacked.reshape(-1, n_columns)
    assert isinstance(flat, jnp.ndarray)
    return flat
def denormalize(
    x: ComplexArrayND, mean: float = 0.0, std: float = 1.0
) -> ComplexArrayND:
    """Undo a normalization: scale by ``std``, then shift by ``mean``.

    Args:
        x: the normalized value(s).
        mean: the offset to add back (default 0.0).
        std: the scale to multiply with (default 1.0).

    Returns:
        ``x * std + mean``.
    """
    rescaled = x * std
    return rescaled + mean
# Container for normalization statistics; fields are ``mean`` and ``std``.
norm = namedtuple("norm", ("mean", "std"))


def get_normalization(x: ComplexArrayND) -> norm:
    """Get mean and standard deviation for a given array.

    Args:
        x: an array exposing ``.mean(axis)`` and ``.std(axis)``, or a plain
            Python scalar (``int``, ``float`` or ``complex``).

    Returns:
        A ``norm`` namedtuple ``(mean, std)``. For arrays the statistics are
        taken along axis 0. For a plain scalar the mean is the scalar itself
        and the std is ``0.0``.
    """
    if isinstance(x, (complex, float, int)):
        # Return the same namedtuple type as the array branch so callers can
        # use `.mean` / `.std` regardless of the input kind. A namedtuple is
        # still a tuple, so unpacking and comparison with `(x, 0.0)` keep
        # working.
        # NOTE(review): a std of 0.0 leads to a division by zero when it is
        # later used as a divisor (e.g. in `normalize`) - confirm callers
        # handle the scalar case.
        return norm(x, 0.0)
    return norm(x.mean(0), x.std(0))
def get_df_columns(df: pd.DataFrame, *names: str) -> Tuple[ComplexArrayND, ...]:
    """Extract the requested DataFrame columns as flat jax.numpy arrays.

    Args:
        df: the DataFrame to read from.
        *names: the column names to extract. They also become the field
            names of the returned namedtuple.

    Returns:
        A ``params`` namedtuple with one raveled ``jnp`` array per requested
        column, in the order the names were given.
    """
    # create the record type first so invalid field names fail up front
    record_type = namedtuple("params", names)
    flattened = []
    for column_name in names:
        as_jax = jnp.array(df[column_name].values)
        assert isinstance(as_jax, jnp.ndarray)
        flattened.append(as_jax.ravel())
    return record_type._make(flattened)
def normalize(x: ComplexArrayND, mean: float = 0.0, std: float = 1.0) -> ComplexArrayND:
    """Normalize values: shift by ``mean``, then divide by ``std``.

    Args:
        x: the value(s) to normalize.
        mean: the offset to subtract (default 0.0).
        std: the scale to divide by (default 1.0).

    Returns:
        ``(x - mean) / std``.
    """
    centered = x - mean
    return centered / std