Source code for sax.models

""" SAX Default Models """

from __future__ import annotations

from functools import lru_cache as cache
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from .saxtypes import FloatArrayND, Model, SCoo, SDict
from .utils import get_inputs_outputs, reciprocal


def straight(
    *,
    wl: FloatArrayND | float = 1.55,
    wl0: float = 1.55,
    neff: float = 2.34,
    ng: float = 3.4,
    length: float = 10.0,
    loss: float = 0.0,
) -> SDict:
    """A simple straight waveguide model with linear dispersion.

    Args:
        wl: wavelength(s) at which the model is evaluated.
        wl0: reference wavelength at which the effective index equals ``neff``.
        neff: effective index at ``wl0``.
        ng: group index; sets how fast the effective index changes with ``wl``.
        length: waveguide length, in the same unit as ``wl`` (the phase
            ``2*pi*neff*length/wl`` must be dimensionless).
        loss: attenuation in dB per unit of ``length``.

    Returns:
        The ``("in0", "out0")`` transmission, passed through ``reciprocal``
        (presumably adding the reverse-direction entry).
    """
    # Linearized effective index: it decreases by (ng - neff) / wl0 per unit
    # of wavelength away from wl0.
    slope = (ng - neff) / wl0
    effective_index = neff - (wl - wl0) * slope
    phase = 2 * jnp.pi * effective_index * length / wl

    # dB -> field amplitude: a power loss of L dB scales the amplitude by 10**(-L/20).
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)

    return reciprocal({("in0", "out0"): amplitude * jnp.exp(1j * phase)})
def coupler(*, coupling: float = 0.5) -> SDict:
    """A simple 2x2 coupler model.

    Args:
        coupling: fraction of power sent to the cross port. The cross
            amplitude is ``1j * sqrt(coupling)`` and the through amplitude is
            ``sqrt(1 - coupling)``, so their squared magnitudes sum to 1.

    Returns:
        The four in->out transmissions, passed through ``reciprocal``
        (presumably adding the reverse-direction entries).
    """
    through = (1 - coupling) ** 0.5
    cross = 1j * coupling**0.5  # the cross path picks up a 90-degree phase
    return reciprocal(
        {
            ("in0", "out0"): through,
            ("in0", "out1"): cross,
            ("in1", "out0"): cross,
            ("in1", "out1"): through,
        }
    )
def _validate_ports( ports, num_inputs, num_outputs, diagonal ) -> Tuple[Tuple[str, ...], Tuple[str, ...], int, int]: if ports is None: if num_inputs is None or num_outputs is None: raise ValueError( "if not ports given, you must specify how many input ports " "and how many output ports a model has." ) input_ports = [f"in{i}" for i in range(num_inputs)] output_ports = [f"out{i}" for i in range(num_outputs)] else: if num_inputs is not None: if num_outputs is None: raise ValueError( "if num_inputs is given, num_outputs should be given as well." ) if num_outputs is not None: if num_inputs is None: raise ValueError( "if num_outputs is given, num_inputs should be given as well." ) if num_inputs is not None and num_outputs is not None: if num_inputs + num_outputs != len(ports): raise ValueError("num_inputs + num_outputs != len(ports)") input_ports = ports[:num_inputs] output_ports = ports[num_inputs:] else: input_ports, output_ports = get_inputs_outputs(ports) num_inputs = len(input_ports) num_outputs = len(output_ports) if diagonal: if num_inputs != num_outputs: raise ValueError( "Can only have a diagonal passthru if number " "of input ports equals the number of output ports!" ) return tuple(input_ports), tuple(output_ports), num_inputs, num_outputs
@cache
def unitary(
    num_inputs: Optional[int] = None,
    num_outputs: Optional[int] = None,
    ports: Optional[Tuple[str, ...]] = None,
    *,
    jit=True,
    reciprocal=True,
    diagonal=False,
) -> Model:
    """Build a wavelength-independent model from a normalized connectivity matrix.

    A ``2N x 2N`` 0/1 connectivity matrix is laid out with
    ``N = max(num_inputs, num_outputs)``: rows/columns ``[0, N)`` are inputs
    and ``[N, 2N)`` are outputs. All of its nonzero singular values are set
    to one (``U @ diag(s > 1e-12) @ Vh``). The element-wise square root is
    then taken (presumably turning power fractions into amplitudes), and only
    the rows/columns of the real ports are kept.

    Args:
        num_inputs: number of input ports (see ``_validate_ports``).
        num_outputs: number of output ports.
        ports: explicit port names; a tuple, so the call stays hashable for
            the cache.
        jit: wrap the returned model in ``jax.jit``.
        reciprocal: also connect outputs back to inputs.
        diagonal: connect input ``i`` only to output ``i`` instead of every
            input to every output; requires ``num_inputs == num_outputs``.

    Returns:
        A model ``f(wl=1.5)`` returning ``(Si, Sj, Sx, port_map)`` in SCOO
        form, with ``Sx`` broadcast over the shape of ``wl``. Results are
        cached, so identical calls return the same function object.

    NOTE(review): when ``num_inputs != num_outputs`` the padded matrix contains
    phantom ports that are dropped afterwards, so part of the normalized power
    is not represented in the returned matrix.
    """
    input_ports, output_ports, num_inputs, num_outputs = _validate_ports(
        ports, num_inputs, num_outputs, diagonal
    )
    assert num_inputs is not None and num_outputs is not None

    size = max(num_inputs, num_outputs)

    # Index expressions for the input->output block and its mirror image.
    if diagonal:
        # One-to-one links; _validate_ports guarantees num_inputs == num_outputs.
        link = jnp.arange(size, dtype=int)
        forward = (link, size + link)
        backward = (size + link, link)
    else:
        forward = (slice(None, size), slice(size, None))
        backward = (slice(size, None), slice(None, size))

    square = jnp.zeros((2 * size, 2 * size), dtype=float).at[forward].set(1)
    if reciprocal:
        square = square.at[backward].set(1)

    # Normalize: replace every nonzero singular value by one.
    u, sv, vh = jnp.linalg.svd(square, full_matrices=False)
    projector = jnp.diag(jnp.where(sv > 1e-12, 1, 0))
    normalized = jnp.sqrt(u @ projector @ vh)

    # Keep only the rows/columns belonging to ports that actually exist.
    keep = jnp.concatenate(
        [jnp.arange(num_inputs, dtype=int), size + jnp.arange(num_outputs, dtype=int)],
        0,
    )
    sub = normalized[keep, :][:, keep]

    # Sparse (COO) representation of the non-negligible entries.
    rows, cols = jnp.where(sub > 1e-6)
    values = sub[rows, cols]

    # Inputs occupy indices [0, num_inputs), outputs follow.
    port_map = {name: idx for idx, name in enumerate(input_ports)}
    port_map.update(
        {name: idx + num_inputs for idx, name in enumerate(output_ports)}
    )

    def model(wl: float = 1.5) -> SCoo:
        batch_shape = jnp.asarray(wl).shape
        return rows, cols, jnp.broadcast_to(values, (*batch_shape, *values.shape)), port_map

    model.__name__ = f"unitary_{num_inputs}_{num_outputs}"
    model.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
    return jax.jit(model) if jit else model
@cache
def copier(
    num_inputs: Optional[int] = None,
    num_outputs: Optional[int] = None,
    ports: Optional[Tuple[str, ...]] = None,
    *,
    jit=True,
    reciprocal=True,
    diagonal=False,
) -> Model:
    """Build a model that copies each input onto its connected outputs.

    Unlike ``unitary``, the connection matrix is *not* normalized. Every
    connected input/output pair has a transmission of exactly 1, so e.g. a
    1-to-2 copier sends full amplitude to both outputs.

    Args:
        num_inputs: number of input ports (see ``_validate_ports``).
        num_outputs: number of output ports.
        ports: explicit port names; a tuple, so the call stays hashable for
            the cache.
        jit: wrap the returned model in ``jax.jit``.
        reciprocal: also connect outputs back to inputs.
        diagonal: connect input ``i`` only to output ``i`` instead of every
            input to every output; requires ``num_inputs == num_outputs``.

    Returns:
        A model ``f(wl=1.5)`` named ``copier_{num_inputs}_{num_outputs}``,
        returning ``(Si, Sj, Sx, port_map)`` in SCOO form, with ``Sx``
        broadcast over the shape of ``wl``. Results are cached, so identical
        calls return the same function object.
    """
    input_ports, output_ports, num_inputs, num_outputs = _validate_ports(
        ports, num_inputs, num_outputs, diagonal
    )
    assert num_inputs is not None and num_outputs is not None

    # Square connection matrix over all ports: inputs first, then outputs.
    num_ports = num_inputs + num_outputs
    S = jnp.zeros((num_ports, num_ports), dtype=float)
    if diagonal:
        # One-to-one links; _validate_ports guarantees num_inputs == num_outputs.
        r = jnp.arange(num_inputs, dtype=int)
        S = S.at[r, num_inputs + r].set(1)
        if reciprocal:
            S = S.at[num_inputs + r, r].set(1)
    else:
        # Every input connects to every output.
        S = S.at[:num_inputs, num_inputs:].set(1)
        if reciprocal:
            S = S.at[num_inputs:, :num_inputs].set(1)

    # Sparse (COO) representation of the nonzero entries.
    Si, Sj = jnp.where(S > 1e-6)
    Sx = S[Si, Sj]

    # Inputs occupy indices [0, num_inputs), outputs follow.
    pm = {
        **{p: i for i, p in enumerate(input_ports)},
        **{p: i + num_inputs for i, p in enumerate(output_ports)},
    }

    def func(wl: float = 1.5) -> SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    # Previously this was (wrongly) copy-pasted as "unitary_…".
    name = f"copier_{num_inputs}_{num_outputs}"
    func.__name__ = name
    func.__qualname__ = name
    if jit:
        return jax.jit(func)
    return func
@cache
def passthru(
    num_links: Optional[int] = None,
    ports: Optional[Tuple[str, ...]] = None,
    *,
    jit=True,
    reciprocal=True,
) -> Model:
    """Build a model that links input ``i`` to output ``i``.

    This is a thin wrapper around ``unitary(num_links, num_links, ports,
    diagonal=True)`` that gives the model its own name.

    Args:
        num_links: number of input/output pairs.
        ports: explicit port names; a tuple, so the call stays hashable for
            the cache.
        jit: wrap the returned model in ``jax.jit``.
        reciprocal: also connect outputs back to inputs.

    Returns:
        A model ``f(wl=1.5)`` named ``passthru_{num_links}_{num_links}`` that
        returns the same SCOO tuple as the underlying diagonal unitary.
    """
    # ``unitary`` is lru-cached, so the function it returns is shared with
    # every other caller using the same arguments. It must not be renamed in
    # place; wrap it in a fresh function instead. Request it un-jitted so
    # ``jax.jit`` is applied exactly once, below.
    base = unitary(
        num_links,
        num_links,
        ports,
        jit=False,
        reciprocal=reciprocal,
        diagonal=True,
    )

    def func(wl: float = 1.5) -> SCoo:
        return base(wl=wl)

    name = f"passthru_{num_links}_{num_links}"
    func.__name__ = name
    func.__qualname__ = name
    if jit:
        return jax.jit(func)
    return func
# Registry of the default models, keyed by the name each one is exposed under.
models = dict(
    copier=copier,
    coupler=coupler,
    passthru=passthru,
    straight=straight,
    unitary=unitary,
)
def get_models(copy: bool = True):
    """Return the registry of default models.

    Args:
        copy: if true (the default), return a shallow copy so callers can add
            or remove entries without touching the module-level ``models``
            dict; otherwise return that dict itself.
    """
    return dict(models) if copy else models