""" SAX Types and type coercions """
from __future__ import annotations
import functools
import inspect
from collections.abc import Callable as CallableABC
from typing import Any, Callable, Dict, Tuple, Union, cast, overload
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array as Array
from jaxtyping import ArrayLike as ArrayLike
from jaxtyping import Complex as Complex
from jaxtyping import Float as Float
from jaxtyping import Int as Int
from natsort import natsorted
IntArray1D = Int[Array, " dim"]
""" One dimensional integer array """
FloatArray1D = Complex[Array, " dim"]
""" One dimensional float array """
ComplexArray1D = Complex[Array, " dim"]
""" One dimensional complex array """
IntArrayND = Int[Array, "..."]
""" N-dimensional integer array """
FloatArrayND = Complex[Array, "..."]
""" N-dimensional float array """
ComplexArrayND = Complex[Array, "..."]
""" N-dimensional complex array """
PortMap = Dict[str, int]
""" A mapping from a port name (str) to a port index (int) """
PortCombination = Tuple[str, str]
""" A combination of two port names (str, str) """
SDict = Dict[PortCombination, ComplexArrayND]
""" A mapping from a port combination to an S-parameter or an array of S-parameters
Example:
.. code-block::
sdict: sax.SDict = {
("in0", "out0"): 3.0,
}
"""
SDense = Tuple[ComplexArrayND, PortMap]
""" A dense S-matrix (2D array) or multidimensional batched S-matrix (N+2)-D array
combined with a port map. If (N+2)-D array the S-matrix dimensions are the last two.
Example:
.. code-block::
Sd = jnp.arange(9, dtype=float).reshape(3, 3)
port_map = {"in0": 0, "in1": 2, "out0": 1}
sdense = Sd, port_map
"""
SCoo = Tuple[IntArray1D, IntArray1D, ComplexArrayND, PortMap]
""" A sparse S-matrix in COO format (recommended for internal library use only)
An `SCoo` is a sparse matrix based representation of an S-matrix consisting of three arrays and a port map.
The three arrays represent the input port indices [`int`], output port indices [`int`] and the S-matrix values [`ComplexFloat`] of the sparse matrix.
The port map maps a port name [`str`] to a port index [`int`]. Only these four arrays **together** and in this specific **order** are considered a valid `SCoo` representation!
Example:
.. code-block::
Si = jnp.arange(3, dtype=int)
Sj = jnp.array([0, 1, 0], dtype=int)
Sx = jnp.array([3.0, 4.0, 1.0])
port_map = {"in0": 0, "in1": 2, "out0": 1}
scoo: sax.SCoo = (Si, Sj, Sx, port_map)
"""
Settings = Dict[str, Union["Settings", FloatArrayND, ComplexArrayND]]
""" A (possibly recursive) mapping from a setting name to a float or complex value or array
Example:
.. code-block::
mzi_settings: sax.Settings = {
"wl": 1.5, # global settings
"lft": {"coupling": 0.5}, # settings for the left coupler
"top": {"neff": 3.4}, # settings for the top waveguide
"rgt": {"coupling": 0.3}, # settings for the right coupler
}
"""
SType = Union[SDict, SCoo, SDense]
""" An SDict, SDense or SCOO """
Model = Callable[..., SType]
""" A keyword-only function producing an SType """
ModelFactory = Callable[..., Model]
""" A keyword-only function producing a Model """
[docs]def is_float(x: Any) -> bool:
"""Check if an object is a `Float`"""
if isinstance(x, float):
return True
if isinstance(x, np.ndarray):
return x.dtype in (np.float16, np.float32, np.float64, np.float128)
if isinstance(x, jnp.ndarray):
return x.dtype in (jnp.float16, jnp.float32, jnp.float64)
return False
[docs]def is_complex(x: Any) -> bool:
"""check if an object is a `ComplexFloat`"""
if isinstance(x, complex):
return True
if isinstance(x, np.ndarray):
return x.dtype in (np.complex64, np.complex128)
if isinstance(x, jnp.ndarray):
return x.dtype in (jnp.complex64, jnp.complex128)
return False
[docs]def is_complex_float(x: Any) -> bool:
"""check if an object is either a `ComplexFloat` or a `Float`"""
return is_float(x) or is_complex(x)
[docs]def is_sdict(x: Any) -> bool:
"""check if an object is an `SDict` (a SAX S-dictionary)"""
return isinstance(x, dict)
[docs]def is_scoo(x: Any) -> bool:
"""check if an object is an `SCoo` (a SAX sparse S representation in COO-format)"""
return isinstance(x, (tuple, list)) and len(x) == 4
[docs]def is_sdense(x: Any) -> bool:
"""check if an object is an `SDense` (a SAX dense S-matrix representation)"""
return isinstance(x, (tuple, list)) and len(x) == 2
[docs]def is_model(model: Any) -> bool:
"""check if a callable is a `Model` (a callable returning an `SType`)"""
if not callable(model):
return False
try:
sig = inspect.signature(model)
except ValueError:
return False
for param in sig.parameters.values():
if param.default is inspect.Parameter.empty:
return False # a proper SAX model does not have any positional arguments.
if _is_callable_annotation(sig.return_annotation): # model factory
return False
return True
def _is_callable_annotation(annotation: Any) -> bool:
"""check if an annotation is `Callable`-like"""
if isinstance(annotation, str):
# happens when
# was imported at the top of the file...
return annotation.startswith("Callable") or annotation.endswith("Model")
# TODO: this is not a very robust check...
try:
return annotation.__origin__ == CallableABC
except AttributeError:
return False
[docs]def is_model_factory(model: Any) -> bool:
"""check if a callable is a model function."""
if not callable(model):
return False
sig = inspect.signature(model)
if _is_callable_annotation(sig.return_annotation): # model factory
return True
return False
[docs]def validate_model(model: Callable):
"""Validate the parameters of a model"""
positional_arguments = []
for param in inspect.signature(model).parameters.values():
if param.default is inspect.Parameter.empty:
positional_arguments.append(param.name)
if positional_arguments:
raise ValueError(
f"model '{model}' takes positional "
f"arguments {', '.join(positional_arguments)} "
"and hence is not a valid SAX Model! "
"A SAX model should ONLY take keyword arguments (or no arguments at all)."
)
[docs]def is_stype(stype: Any) -> bool:
"""check if an object is an SDict, SCoo or SDense"""
return is_sdict(stype) or is_scoo(stype) or is_sdense(stype)
[docs]def is_singlemode(S: Any) -> bool:
"""check if an stype is single mode"""
if not is_stype(S):
return False
ports = _get_ports(S)
return not any(("@" in p) for p in ports)
def _get_ports(S: SType):
if is_sdict(S):
S = cast(SDict, S)
ports_set = {p1 for p1, _ in S} | {p2 for _, p2 in S}
return tuple(natsorted(ports_set))
else:
*_, ports_map = S
assert isinstance(ports_map, dict)
return tuple(natsorted(ports_map.keys()))
[docs]def is_multimode(S: Any) -> bool:
"""check if an stype is single mode"""
if not is_stype(S):
return False
ports = _get_ports(S)
return all(("@" in p) for p in ports)
[docs]def is_mixedmode(S: Any) -> bool:
"""check if an stype is neither single mode nor multimode (hence invalid)"""
return not is_singlemode(S) and not is_multimode(S)
@overload
def sdict(S: Model) -> Model:
...
@overload
def sdict(S: SType) -> SDict:
...
[docs]def sdict(S: Union[Model, SType]) -> Union[Model, SType]:
"""Convert an `SCoo` or `SDense` to `SDict`"""
if is_model(S):
model = cast(Model, S)
@functools.wraps(model)
def wrapper(**kwargs):
return sdict(model(**kwargs))
return wrapper
elif is_scoo(S):
x_dict = _scoo_to_sdict(*cast(SCoo, S))
elif is_sdense(S):
x_dict = _sdense_to_sdict(*cast(SDense, S))
elif is_sdict(S):
x_dict = cast(SDict, S)
else:
raise ValueError("Could not convert arguments to sdict.")
return x_dict
def _scoo_to_sdict(
Si: IntArray1D,
Sj: IntArray1D,
Sx: ComplexArrayND,
ports_map: Dict[str, int],
) -> SDict:
sdict = {}
inverse_ports_map = {int(i): p for p, i in ports_map.items()}
for i, (si, sj) in enumerate(zip(Si, Sj)):
input_port = inverse_ports_map.get(int(si), "")
output_port = inverse_ports_map.get(int(sj), "")
sdict[input_port, output_port] = Sx[..., i]
sdict = {(p1, p2): v for (p1, p2), v in sdict.items() if p1 and p2}
return sdict
def _sdense_to_sdict(S: Array, ports_map: Dict[str, int]) -> SDict:
sdict = {}
for p1, i in ports_map.items():
for p2, j in ports_map.items():
sdict[p1, p2] = S[..., i, j]
return sdict
@overload
def scoo(S: Callable) -> Callable:
...
@overload
def scoo(S: SType) -> SCoo:
...
[docs]def scoo(S: Union[Callable, SType]) -> Union[Callable, SCoo]:
"""Convert an `SDict` or `SDense` to `SCoo`"""
if is_model(S):
model = cast(Model, S)
@functools.wraps(model)
def wrapper(**kwargs):
return scoo(model(**kwargs))
return wrapper
elif is_scoo(S):
S = cast(SCoo, S)
elif is_sdense(S):
S = _sdense_to_scoo(*cast(SDense, S))
elif is_sdict(S):
S = _sdict_to_scoo(cast(SDict, S))
else:
raise ValueError("Could not convert arguments to scoo.")
return S
def _consolidate_sdense(S, pm):
idxs = list(pm.values())
S = S[..., idxs, :][..., :, idxs]
pm = {p: i for i, p in enumerate(pm)}
return S, pm
def _sdense_to_scoo(S: Array, ports_map: Dict[str, int]) -> SCoo:
S, ports_map = _consolidate_sdense(S, ports_map)
Sj, Si = jnp.meshgrid(jnp.arange(S.shape[-1]), jnp.arange(S.shape[-2]))
return Si.ravel(), Sj.ravel(), S.reshape(*S.shape[:-2], -1), ports_map
def _sdict_to_scoo(sdict: SDict) -> SCoo:
all_ports = {}
for p1, p2 in sdict:
all_ports[p1] = None
all_ports[p2] = None
ports_map = {p: int(i) for i, p in enumerate(all_ports)}
Sx = jnp.stack(jnp.broadcast_arrays(*sdict.values()), -1)
Si = jnp.array([ports_map[p] for p, _ in sdict])
Sj = jnp.array([ports_map[p] for _, p in sdict])
return Si, Sj, Sx, ports_map
@overload
def sdense(S: Callable) -> Callable:
...
@overload
def sdense(S: SType) -> SDense:
...
[docs]def sdense(S: Union[Callable, SType]) -> Union[Callable, SDense]:
"""Convert an `SDict` or `SCoo` to `SDense`"""
if is_model(S):
model = cast(Model, S)
@functools.wraps(model)
def wrapper(**kwargs):
return sdense(model(**kwargs))
return wrapper
if is_sdict(S):
S = _sdict_to_sdense(cast(SDict, S))
elif is_scoo(S):
S = _scoo_to_sdense(*cast(SCoo, S))
elif is_sdense(S):
S = cast(SDense, S)
else:
raise ValueError("Could not convert arguments to sdense.")
return S
def _scoo_to_sdense(
Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int]
) -> SDense:
n_col = len(ports_map)
S = jnp.zeros((*Sx.shape[:-1], n_col, n_col), dtype=complex)
S = S.at[..., Si, Sj].add(Sx)
return S, ports_map
def _sdict_to_sdense(sdict: SDict) -> SDense:
Si, Sj, Sx, ports_map = _sdict_to_scoo(sdict)
return _scoo_to_sdense(Si, Sj, Sx, ports_map)
[docs]def modelfactory(func):
"""Decorator that marks a function as `ModelFactory`"""
sig = inspect.signature(func)
if _is_callable_annotation(sig.return_annotation): # already model factory
return func
func.__signature__ = sig.replace(return_annotation=Model)
return func