Source code for sax.nn.io

""" SAX Neural Network I/O Utilities """

from __future__ import annotations

import json
import os
import re
from typing import Callable, Dict, List, Optional, Tuple

import jax.numpy as jnp

from ..saxtypes import ComplexArrayND
from .core import dense, preprocess
from .utils import norm


def load_nn_weights_json(path: str) -> Dict[str, ComplexArrayND]:
    """Load json weights from given path"""
    path = os.path.abspath(os.path.expanduser(path))
    weights = {}
    if os.path.exists(path):
        with open(path, "r") as file:
            for k, v in json.load(file).items():
                _v = jnp.array(v, dtype=float)
                assert isinstance(_v, jnp.ndarray)
                weights[k] = _v
    return weights

def save_nn_weights_json(weights: Dict[str, ComplexArrayND], path: str):
    """Save json weights to given path"""
    path = os.path.abspath(os.path.expanduser(path))
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as file:
        _weights = {}
        for k, v in weights.items():
            v = jnp.atleast_1d(jnp.array(v))
            assert isinstance(v, jnp.ndarray)
            _weights[k] = v.tolist()
        json.dump(_weights, file)

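For illustration, a minimal round-trip sketch; the weight keys and the `weights/demo.json` path are made up for this example:

# Hypothetical round-trip: key names and "weights/demo.json" are illustrative only.
import jax.numpy as jnp
from sax.nn.io import load_nn_weights_json, save_nn_weights_json

weights = {"w0": jnp.ones((3, 4)), "b0": jnp.zeros(4)}
save_nn_weights_json(weights, "weights/demo.json")    # arrays are serialized as nested lists
restored = load_nn_weights_json("weights/demo.json")  # values come back as float jnp arrays
assert restored["w0"].shape == (3, 4)
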
def get_available_sizes(
    dirpath: str,
    prefix: str,
    input_names: Tuple[str, ...],
    output_names: Tuple[str, ...],
) -> List[Tuple[int, ...]]:
    """Get all available json weight hidden sizes given filename parameters

    > Note: this function does NOT return the input size and the output size
    of the neural network. ONLY the hidden sizes are reported. The input and
    output sizes can easily be derived from `input_names` (after preprocessing)
    and `output_names`.
    """
    all_weightfiles = os.listdir(dirpath)
    possible_weightfiles = (
        s for s in all_weightfiles if s.endswith(f"-{'-'.join(output_names)}.json")
    )
    possible_weightfiles = (
        s
        for s in possible_weightfiles
        if s.startswith(f"{prefix}-{'-'.join(input_names)}")
    )
    possible_weightfiles = (re.sub("[^0-9x]", "", s) for s in possible_weightfiles)
    possible_weightfiles = (re.sub("^x*", "", s) for s in possible_weightfiles)
    possible_weightfiles = (re.sub("x[^0-9]*$", "", s) for s in possible_weightfiles)
    possible_hidden_sizes = (s.strip() for s in possible_weightfiles if s.strip())
    possible_hidden_sizes = (
        tuple(hs.strip() for hs in s.split("x") if hs.strip())
        for s in possible_hidden_sizes
    )
    possible_hidden_sizes = (
        tuple(int(hs) for hs in s[1:-1]) for s in possible_hidden_sizes if len(s) > 2
    )
    possible_hidden_sizes = sorted(
        possible_hidden_sizes, key=lambda hs: (len(hs), max(hs))
    )
    return possible_hidden_sizes

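A sketch of the filename convention this parser assumes; the file names and parameter names below are hypothetical, and only the digits-and-`x` block in the middle of each name is interpreted:

# Hypothetical files in "weights/", following the SAX naming convention:
#   dense-wl-t-5x32x32x2-amp-phi.json  ->  hidden sizes (32, 32)
#   dense-wl-t-5x64x2-amp-phi.json     ->  hidden sizes (64,)
sizes = get_available_sizes(
    dirpath="weights",
    prefix="dense",
    input_names=("wl", "t"),
    output_names=("amp", "phi"),
)
# -> [(64,), (32, 32)]  (sorted by number of hidden layers, then by largest layer)
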
def get_dense_weights_path(
    *sizes: int,
    input_names: Optional[Tuple[str, ...]] = None,
    output_names: Optional[Tuple[str, ...]] = None,
    dirpath: str = "weights",
    prefix: str = "dense",
    preprocess=preprocess,
):
    """Create the SAX conventional path for a given weight dictionary"""
    if input_names:
        num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0]
        sizes = (num_inputs,) + sizes
    if output_names:
        sizes = sizes + (len(output_names),)
    path = os.path.abspath(os.path.join(dirpath, prefix))
    if input_names:
        path = f"{path}-{'-'.join(input_names)}"
    if sizes:
        path = f"{path}-{'x'.join(str(s) for s in sizes)}"
    if output_names:
        path = f"{path}-{'-'.join(output_names)}"
    return f"{path}.json"

def get_norm_path(
    *shape: int,
    input_names: Optional[Tuple[str, ...]] = None,
    output_names: Optional[Tuple[str, ...]] = None,
    dirpath: str = "norms",
    prefix: str = "norm",
    preprocess=preprocess,
):
    """Create the SAX conventional path for the normalization constants"""
    if input_names and output_names:
        raise ValueError(
            "To get the norm name, one can only specify "
            "`input_names` OR `output_names`."
        )
    if input_names:
        num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0]
        shape = (num_inputs,) + shape
    if output_names:
        shape = shape + (len(output_names),)
    path = os.path.abspath(os.path.join(dirpath, prefix))
    if input_names:
        path = f"{path}-{'-'.join(input_names)}"
    if shape:
        path = f"{path}-{'x'.join(str(s) for s in shape)}"
    if output_names:
        path = f"{path}-{'-'.join(output_names)}"
    return f"{path}.json"

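A sketch of the resulting paths; the parameter names are hypothetical and `<n_in>` stands for the preprocessed input width, which depends on `preprocess`:

# Hypothetical parameter names; <n_in> is the preprocessed input width.
get_dense_weights_path(32, 32, input_names=("wl", "t"), output_names=("amp", "phi"))
# -> "/.../weights/dense-wl-t-<n_in>x32x32x2-amp-phi.json"

get_norm_path(output_names=("amp", "phi"))
# -> "/.../norms/norm-2-amp-phi.json"
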
class _PartialDense:
    def __init__(self, weights, x_norm, y_norm, input_names, output_names):
        self.weights = weights
        self.x_norm = x_norm
        self.y_norm = y_norm
        self.input_names = input_names
        self.output_names = output_names

    def __call__(self, *params: ComplexArrayND) -> ComplexArrayND:
        return dense(self.weights, *params, x_norm=self.x_norm, y_norm=self.y_norm)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"{repr(self.input_names)}"
            f"->{repr(self.output_names)}"
        )

def load_nn_dense(
    *sizes: int,
    input_names: Optional[Tuple[str, ...]] = None,
    output_names: Optional[Tuple[str, ...]] = None,
    weightprefix="dense",
    weightdirpath="weights",
    normdirpath="norms",
    normprefix="norm",
    preprocess=preprocess,
) -> Callable:
    """Load a pre-trained dense model"""
    weights_path = get_dense_weights_path(
        *sizes,
        input_names=input_names,
        output_names=output_names,
        prefix=weightprefix,
        dirpath=weightdirpath,
        preprocess=preprocess,
    )
    if not os.path.exists(weights_path):
        raise ValueError("Cannot find weights path for given parameters")
    x_norm_path = get_norm_path(
        input_names=input_names,
        prefix=normprefix,
        dirpath=normdirpath,
        preprocess=preprocess,
    )
    if not os.path.exists(x_norm_path):
        raise ValueError("Cannot find normalization for input parameters")
    y_norm_path = get_norm_path(
        output_names=output_names,
        prefix=normprefix,
        dirpath=normdirpath,
        preprocess=preprocess,
    )
    if not os.path.exists(y_norm_path):
        raise ValueError("Cannot find normalization for output parameters")
    weights = load_nn_weights_json(weights_path)
    x_norm_dict = load_nn_weights_json(x_norm_path)
    y_norm_dict = load_nn_weights_json(y_norm_path)
    x_norm = norm(x_norm_dict["mean"], x_norm_dict["std"])
    y_norm = norm(y_norm_dict["mean"], y_norm_dict["std"])
    partial_dense = _PartialDense(weights, x_norm, y_norm, input_names, output_names)
    return partial_dense
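
A hypothetical way to use the loader, assuming weight and norm files following the conventions above exist under the default `weights` and `norms` directories (the parameter names and values are made up):

model = load_nn_dense(32, 32, input_names=("wl", "t"), output_names=("amp", "phi"))
print(model)                    # _PartialDense('wl', 't')->('amp', 'phi')
prediction = model(1.55, 25.0)  # positional parameters, presumably in `input_names` order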