# Source code for sax.multimode

""" SAX Multimode support """

from __future__ import annotations

from functools import wraps
from typing import Dict, Tuple, Union, cast, overload

import jax.numpy as jnp

from .saxtypes import (
    Model,
    SCoo,
    SDense,
    SDict,
    SType,
    _consolidate_sdense,
    is_model,
    is_multimode,
    is_scoo,
    is_sdense,
    is_sdict,
    is_singlemode,
)
from .utils import (
    block_diag,
    mode_combinations,
    validate_multimode,
    validate_not_mixedmode,
)


@overload
def multimode(S: Model, modes: Tuple[str, ...] = ("TE", "TM")) -> Model: ...


@overload
def multimode(S: SDict, modes: Tuple[str, ...] = ("TE", "TM")) -> SDict: ...


@overload
def multimode(S: SCoo, modes: Tuple[str, ...] = ("TE", "TM")) -> SCoo: ...


@overload
def multimode(S: SDense, modes: Tuple[str, ...] = ("TE", "TM")) -> SDense: ...


def multimode(
    S: Union[SType, Model], modes: Tuple[str, ...] = ("TE", "TM")
) -> Union[SType, Model]:
    """Convert a single mode model to a multimode model.

    Args:
        S: a model function, or an S-matrix given as an SDict, SCoo or SDense.
        modes: the modes every port is expanded into.

    Returns:
        For a model: a wrapped model whose every returned S-matrix is
        converted with the same ``modes``. For an S-matrix: its multimode
        counterpart in the same format. An S-matrix that is already
        multimode is validated against ``modes`` and returned unchanged.

    Raises:
        ValueError: if ``S`` is an S-matrix of an unrecognized format.
        The ``validate_not_mixedmode`` / ``validate_multimode`` helpers may
        also reject invalid input (see their own documentation).
    """
    if is_model(S):
        wrapped_model = cast(Model, S)

        @wraps(wrapped_model)
        def converted_model(**params):
            # convert lazily: the S-matrix only exists once params are known
            return multimode(wrapped_model(**params), modes=modes)

        return cast(Model, converted_model)

    stype = cast(SType, S)
    validate_not_mixedmode(stype)
    if is_multimode(stype):
        validate_multimode(stype, modes=modes)
        return stype

    # checked in order; the first matching format wins
    converters = (
        (is_sdict, _multimode_sdict),
        (is_scoo, _multimode_scoo),
        (is_sdense, _multimode_sdense),
    )
    for matches_format, convert in converters:
        if matches_format(stype):
            return convert(stype, modes=modes)
    raise ValueError("cannot convert to multimode. Unknown stype.")


def _multimode_sdict(sdict: SDict, modes: Tuple[str, ...] = ("TE", "TM")) -> SDict:
    """Expand an SDict over modes.

    Every ``(p1, p2)`` entry is copied to ``(f"{p1}@{m1}", f"{p2}@{m2}")``
    for each ``(m1, m2)`` produced by ``mode_combinations(modes)``.
    """
    # computed once and reused for every entry of the SDict
    mode_pairs = mode_combinations(modes)
    return {
        (f"{port_in}@{mode_in}", f"{port_out}@{mode_out}"): value
        for (port_in, port_out), value in sdict.items()
        for mode_in, mode_out in mode_pairs
    }


def _mode_index(modes: Tuple[str, ...]) -> Dict[str, int]:
    """Map each mode to its block position; a dict of modes is used as-is."""
    if isinstance(modes, dict):
        return cast(Dict, modes)
    return {mode: position for position, mode in enumerate(modes)}


def _multimode_port_map(
    port_map: Dict[str, int],
    modes: Tuple[str, ...],
    mode_index: Dict[str, int],
    block_size: int,
) -> Dict[str, int]:
    """Name every port ``port@mode`` and shift its index into that mode's block."""
    return {
        f"{port}@{mode}": idx + mode_index[mode] * block_size
        for mode in modes
        for port, idx in port_map.items()
    }


def _multimode_scoo(scoo: SCoo, modes: Tuple[str, ...] = ("TE", "TM")) -> SCoo:
    """Expand an SCoo over modes.

    The (Si, Sj, Sx) triplets are repeated once per mode pair of
    ``mode_combinations(modes)``; the row (column) indices of each copy are
    offset by ``len(port_map)`` times the position of the row (column) mode.
    """
    Si, Sj, Sx, port_map = scoo
    block_size = len(port_map)
    mode_index = _mode_index(modes)
    mode_pairs = mode_combinations(modes)
    # NOTE(review): offsetting by len(port_map) assumes Si/Sj only reference
    # indices below len(port_map). An SCoo whose port map was shrunk without
    # touching the data (as _singlemode_scoo does) may keep larger indices,
    # which would then overlap with the next mode's block -- confirm.
    rows = [Si + mode_index[row_mode] * block_size for row_mode, _ in mode_pairs]
    cols = [Sj + mode_index[col_mode] * block_size for _, col_mode in mode_pairs]
    values = [Sx for _ in mode_pairs]
    return (
        jnp.concatenate(rows, -1),
        jnp.concatenate(cols, -1),
        jnp.concatenate(values, -1),
        _multimode_port_map(port_map, modes, mode_index, block_size),
    )


def _multimode_sdense(sdense: SDense, modes: Tuple[str, ...] = ("TE", "TM")) -> SDense:
    """Expand an SDense over modes: one copy of Sx on the diagonal per mode."""
    Sx, port_map = sdense
    block_size = len(port_map)
    mode_index = _mode_index(modes)
    diagonal = block_diag(*[Sx for _ in modes])
    return diagonal, _multimode_port_map(port_map, modes, mode_index, block_size)


@overload
def singlemode(S: Model, mode: str = "TE") -> Model: ...


@overload
def singlemode(S: SDict, mode: str = "TE") -> SDict: ...


@overload
def singlemode(S: SCoo, mode: str = "TE") -> SCoo: ...


@overload
def singlemode(S: SDense, mode: str = "TE") -> SDense: ...


def singlemode(S: Union[SType, Model], mode: str = "TE") -> Union[SType, Model]:
    """Convert multimode model to a singlemode model.

    Args:
        S: a model function, or an S-matrix given as an SDict, SCoo or SDense.
        mode: the single mode to keep.

    Returns:
        For a model: a wrapped model whose every returned S-matrix is
        converted with the same ``mode``. For an S-matrix: only the ports
        belonging to ``mode`` are kept, renamed without their mode suffix,
        in the same format. An S-matrix that is already singlemode is
        returned unchanged, whatever ``mode`` is.

    Raises:
        ValueError: if ``S`` is an S-matrix of an unrecognized format.
        ``validate_not_mixedmode`` may also reject invalid input (see its
        own documentation).
    """
    if is_model(S):
        model = cast(Model, S)

        @wraps(model)
        def new_model(**params):
            # convert lazily: the S-matrix only exists once params are known
            return singlemode(model(**params), mode=mode)

        return cast(Model, new_model)

    S = cast(SType, S)
    validate_not_mixedmode(S)
    if is_singlemode(S):
        return S
    if is_sdict(S):
        return _singlemode_sdict(cast(SDict, S), mode=mode)
    elif is_scoo(S):
        return _singlemode_scoo(cast(SCoo, S), mode=mode)
    elif is_sdense(S):
        return _singlemode_sdense(cast(SDense, S), mode=mode)
    else:
        # message previously (wrongly) said "multimode"
        raise ValueError("cannot convert to singlemode. Unknown stype.")


def _strip_suffix(port: str, suffix: str) -> str:
    """Remove ``suffix`` from the end of ``port`` (caller checked it is there)."""
    # suffix always starts with "@", so it is never empty and this slice is safe
    return port[: -len(suffix)]


def _singlemode_port_map(port_map: Dict[str, int], mode: str) -> Dict[str, int]:
    """Keep only the ports of ``mode``, renamed without their ``@mode`` suffix."""
    suffix = f"@{mode}"
    return {
        _strip_suffix(port, suffix): idx
        for port, idx in port_map.items()
        if port.endswith(suffix)
    }


def _singlemode_sdict(sdict: SDict, mode: str = "TE") -> SDict:
    """Keep only the SDict entries between two ports of ``mode``.

    Both ports of a kept entry lose their ``@{mode}`` suffix, e.g.
    ``("in0@TE", "out0@TE")`` becomes ``("in0", "out0")`` for ``mode="TE"``.
    Entries where either port does not end in ``@{mode}`` are dropped.
    """
    suffix = f"@{mode}"
    return {
        (_strip_suffix(p1, suffix), _strip_suffix(p2, suffix)): value
        for (p1, p2), value in sdict.items()
        if p1.endswith(suffix) and p2.endswith(suffix)
    }


def _singlemode_scoo(scoo: SCoo, mode: str = "TE") -> SCoo:
    """Restrict an SCoo to the ports of ``mode`` by filtering its port map.

    Si, Sj and Sx are returned untouched; entries whose indices are no longer
    named in the port map are left in place (presumably ignored by downstream
    conversions -- TODO confirm).
    """
    Si, Sj, Sx, port_map = scoo
    # no need to touch the data...
    # just removing some ports from the port map should be enough
    return Si, Sj, Sx, _singlemode_port_map(port_map, mode)


def _singlemode_sdense(sdense: SDense, mode: str = "TE") -> SDense:
    """Restrict an SDense to the ports of ``mode``.

    The port map is filtered and renamed, then handed to
    ``_consolidate_sdense`` together with the original Sx (presumably it
    reduces Sx to the remaining ports -- confirm in saxtypes).
    """
    Sx, port_map = sdense
    return _consolidate_sdense(Sx, _singlemode_port_map(port_map, mode))