
sax

SAX: S + Autograd + XLA.

Modules:

Name Description
backends

SAX Backends.

circuits

SAX Circuit Definition.

constants

Constants and magic numbers.

into

Type conversion utilities for SAX.

loss

SAX Loss Functions.

models

SAX Models.

multimode

SAX Multimode support.

netlists

Netlist utilities.

ports

Port naming strategies.

s

SAX S-Matrix utilities.

saxtypes

All types and type-validators used in SAX.

utils

General SAX Utilities.

Classes:

Name Description
CircuitInfo

Information about a SAX circuit function.

IOLike

Protocol for file-like objects that can be read from.

Net

A logical connection between two ports.

Netlist

A complete netlist definition for an optical circuit.

Normalization

Normalization parameters for an array.

Placement

Physical placement information for an instance.

PortNamer

Port naming class to encapsulate port naming logic.

try_into

Optional type converter utility that returns None on failure.

Functions:

Name Description
analyze_circuit_additive

Analyze circuit topology for the additive backend.

analyze_circuit_fg

Analyze circuit topology for the Filipsson-Gunnar backend.

analyze_circuit_forward

Analyze circuit topology for the forward-only backend.

analyze_instances_additive

Analyze circuit instances for the additive backend.

analyze_instances_fg

Analyze circuit instances for the Filipsson-Gunnar backend.

analyze_instances_forward

Analyze circuit instances for the forward-only backend.

block_diag

Create block diagonal matrix with arbitrary batch dimensions.

cartesian_product

Calculate the n-dimensional Cartesian product of input arrays.

circuit

Create a circuit function for a given netlist.

clean_string

Clean a string to create a valid Python identifier.

denormalize

Denormalize an array using provided normalization parameters.

draw_dag

Draw a directed acyclic graph (DAG) representing circuit dependencies.

evaluate_circuit_additive

Evaluate circuit S-matrix using additive path-based method.

evaluate_circuit_fg

Evaluate circuit S-matrix using the Filipsson-Gunnar algorithm.

evaluate_circuit_forward

Evaluate circuit S-matrix using forward-only propagation.

flatten_dict

Flatten a nested dictionary into a single-level dictionary.

flatten_netlist

Flatten a recursive netlist into a single flat netlist.

get_mode

Extract the mode from a port@mode string.

get_modes

Extract the optical modes from a multimode S-matrix.

get_port_combinations

Extract all port pair combinations from an S-matrix.

get_port_naming_strategy

Get the current port naming strategy.

get_ports

Extract port names from an S-matrix.

get_required_circuit_models

Determine which component models are required for a given netlist.

get_settings

Extract default parameter settings from a SAX model function.

grouped_interp

Perform grouped phase interpolation for optical phase data.

hash_dict

Compute a hash value for a dictionary.

huber_loss

Compute Huber loss between two complex arrays.

l2_reg

Compute L2 regularization loss for model weights.

load_netlist

Load a SAX netlist from YAML content or file.

load_recursive_netlist

Load a SAX recursive netlist from a directory of YAML files.

maybe

Create a safe version of a function that returns None on exceptions.

merge_dicts

Merge multiple dictionaries with support for nested merging.

mse

Compute mean squared error between two complex arrays.

netlist

Convert a netlist to recursive netlist format.

normalization

Calculate normalization parameters (mean and std) for an array.

normalize

Normalize an array using provided normalization parameters.

read

Read content from string, file path, or file-like object.

reciprocal

Make an SDict S-matrix reciprocal by ensuring S[i,j] = S[j,i].

rename_params

Rename the parameters of a model function.

rename_ports

Rename the ports of an S-matrix or model.

replace_kwargs

Change the kwargs signature of a function.

scoo

Convert an S-matrix to SCoo (coordinate) format.

sdense

Convert an S-matrix to SDense (dense matrix) format.

sdict

Convert an S-matrix to SDict (dictionary) format.

set_port_naming_strategy

Set the port naming strategy for the default SAX models.

singlemode

Convert a multimode S-matrix or model to single-mode.

unflatten_dict

Unflatten a dictionary by splitting keys and creating nested structure.

update_settings

Update a nested settings dictionary with new parameter values.

wl_c

Generate wavelength array in the C-band (Conventional band: 1530-1565 nm).

wl_e

Generate wavelength array in the E-band (Extended band: 1360-1460 nm).

wl_l

Generate wavelength array in the L-band (Long band: 1565-1625 nm).

wl_o

Generate wavelength array in the O-band (Original band: 1260-1360 nm).

wl_s

Generate wavelength array in the S-band (Short band: 1460-1530 nm).

Attributes:

Name Type Description
AnyNetlist TypeAlias

Any valid netlist format: flat, recursive, or simplified dictionary.

ArrayLike TypeAlias

Anything that can turn into an array with ndim>=1.

Backend TypeAlias

Available SAX backend algorithms for circuit simulation.

BackendLike TypeAlias

Backend specification allowing 'default' to use the system default backend.

Bool TypeAlias

Any boolean value (Python bool or NumPy boolean).

BoolArray TypeAlias

N-dimensional boolean array (JAX Array with boolean dtype).

BoolArrayLike TypeAlias

Anything that can be cast into an N-dimensional Bool array without loss of data.

BoolLike TypeAlias

Anything that can be cast into a Bool without loss of data.

C_M_S float

Speed of light in vacuum (m/s).

C_UM_S float

Speed of light in vacuum (μm/s).

Complex TypeAlias

Any complex number (Python complex or NumPy complex floating).

ComplexArray TypeAlias

N-dimensional complex array (JAX Array with complex dtype).

ComplexArray1D TypeAlias

1-dimensional complex array (JAX Array with complex dtype).

ComplexArray1DLike TypeAlias

Anything that can be cast into a 1-dimensional complex array without loss of data.

ComplexArrayLike TypeAlias

Anything that can be cast into an N-dim Complex array without loss of data.

ComplexLike TypeAlias

Anything that can be cast into a Complex without loss of data.

Component TypeAlias

The name of a component model (must be a valid Python identifier).

Connections TypeAlias

A mapping defining point-to-point connections between instance ports.

DEFAULT_MODE str

Default optical mode.

DEFAULT_MODES tuple[str, ...]

Default multimode configuration.

DEFAULT_WL_STEP float

Default wavelength step for array generation (μm).

EPS float

Small numerical epsilon used to decide which S-parameter values can be treated as zero.

Float TypeAlias

Any floating-point number (Python float or NumPy floating).

FloatArray TypeAlias

N-dimensional floating-point array (JAX Array with float dtype).

FloatArray1D TypeAlias

1-dimensional floating-point array (JAX Array with float dtype).

FloatArray1DLike TypeAlias

Anything that can be cast into a 1-dimensional float array without loss of data.

FloatArray2D TypeAlias

2-dimensional floating-point array (JAX Array with float dtype).

FloatArray2DLike TypeAlias

Anything that can be cast into a 2-dimensional float array without loss of data.

FloatArrayLike TypeAlias

Anything that can be cast into an N-dimensional Float array without loss of data.

FloatLike TypeAlias

Anything that can be cast into a Float without loss of data.

Instance TypeAlias

A component instantiation in a netlist, with optional settings and array configuration.

InstanceName TypeAlias

An instance name allowing dots and angle brackets for hierarchical naming.

InstancePort TypeAlias

An instance port reference in the format 'instance_name,port_name'.

Instances TypeAlias

A mapping from instance names to their definitions.

Int TypeAlias

Any signed integer (Python int or NumPy signed integer).

IntArray TypeAlias

N-dimensional signed integer array (JAX Array with integer dtype).

IntArray1D TypeAlias

1-dimensional signed integer array (JAX Array with integer dtype).

IntArray1DLike TypeAlias

Anything that can be cast into a 1-dimensional integer array without loss of data.

IntArrayLike TypeAlias

Anything that can be cast into an N-dimensional Int array without loss of data.

IntLike TypeAlias

Anything that can be cast into an Int without loss of data.

Mode TypeAlias

A mode identifier string (e.g., '0', 'TE0', 'TM1').

Model TypeAlias

A SAX model function that works with either single-mode or multi-mode circuits.

ModelFactory TypeAlias

A SAX model factory function that works with single-mode or multi-mode circuits.

ModelFactoryMM TypeAlias

A keyword-only function that produces any multi-mode model.

ModelFactorySM TypeAlias

A keyword-only function that produces any single-mode model.

ModelMM TypeAlias

A keyword-only function that produces any multi-mode S-matrix type.

ModelSM TypeAlias

A keyword-only function that produces any single-mode S-matrix type.

Models TypeAlias

A collection of model functions, supporting both single-mode and multi-mode.

ModelsMM TypeAlias

A mapping from model names to multi-mode model functions.

ModelsSM TypeAlias

A mapping from model names to single-mode model functions.

Name TypeAlias

A valid Python identifier - contains only letters, numbers, and underscores.

Nets TypeAlias

A list of logical connections between ports.

Placements TypeAlias

A mapping from instance names to their physical placements.

Port TypeAlias

A single-mode port name - must be a valid Python identifier.

PortCombination TypeAlias

A pair of port names, either single-mode or multi-mode format.

PortCombinationMM TypeAlias

A pair of multi-mode port-mode names representing an S-parameter.

PortCombinationSM TypeAlias

A pair of single-mode port names representing an S-parameter.

PortMap TypeAlias

A mapping from port names to matrix indices, supporting both mode types.

PortMapMM TypeAlias

A mapping from multi-mode port-mode names to their matrix indices.

PortMapSM TypeAlias

A mapping from single-mode port names to their matrix indices.

PortMode TypeAlias

A port-mode specification in the format 'port_name@mode_name'.

Ports TypeAlias

A mapping from external circuit ports to internal instance ports.

RecursiveNetlist TypeAlias

A hierarchical netlist containing multiple named circuits.

SCoo TypeAlias

A sparse COO format S-matrix, supporting both single-mode and multi-mode.

SCooMM TypeAlias

A sparse S-matrix in COO format (recommended for internal library use only).

SCooModel TypeAlias

A model function that produces SCoo S-matrices in either mode format.

SCooModelFactory TypeAlias

A model factory that produces SCoo models in either mode format.

SCooModelFactoryMM TypeAlias

A keyword-only function that produces a multi-mode SCoo model.

SCooModelFactorySM TypeAlias

A keyword-only function that produces a single-mode SCoo model.

SCooModelMM TypeAlias

A keyword-only function that produces a multi-mode SCoo S-matrix.

SCooModelSM TypeAlias

A keyword-only function that produces a single-mode SCoo S-matrix.

SCooSM TypeAlias

A sparse S-matrix in COO format (recommended for internal library use only).

SDense TypeAlias

A dense S-matrix representation, supporting both single-mode and multi-mode.

SDenseMM TypeAlias

A dense S-matrix representation.

SDenseModel TypeAlias

A model function that produces SDense S-matrices in either mode format.

SDenseModelFactory TypeAlias

A model factory that produces SDense models in either mode format.

SDenseModelFactoryMM TypeAlias

A keyword-only function that produces a multi-mode SDense model.

SDenseModelFactorySM TypeAlias

A keyword-only function that produces a single-mode SDense model.

SDenseModelMM TypeAlias

A keyword-only function that produces a multi-mode SDense S-matrix.

SDenseModelSM TypeAlias

A keyword-only function that produces a single-mode SDense S-matrix.

SDenseSM TypeAlias

A dense S-matrix representation.

SDict TypeAlias

A dictionary-based sparse S-matrix, supporting both single-mode and multi-mode.

SDictMM TypeAlias

A sparse dictionary-based S-matrix representation.

SDictModel TypeAlias

A model function that produces SDict S-matrices in either mode format.

SDictModelFactory TypeAlias

A model factory that produces SDict models in either mode format.

SDictModelFactoryMM TypeAlias

A keyword-only function that produces a multi-mode SDict model.

SDictModelFactorySM TypeAlias

A keyword-only function that produces a single-mode SDict model.

SDictModelMM TypeAlias

A keyword-only function that produces a multi-mode SDict S-matrix.

SDictModelSM TypeAlias

A keyword-only function that produces a single-mode SDict S-matrix.

SDictSM TypeAlias

A sparse dictionary-based S-matrix representation.

SType TypeAlias

Any S-matrix type (SDict, SDense, SCoo) in single-mode or multi-mode format.

STypeMM TypeAlias

Any multi-mode S-matrix type (SDict, SDense, or SCoo).

STypeSM TypeAlias

Any single-mode S-matrix type (SDict, SDense, or SCoo).

Settings TypeAlias

A (possibly nested) settings mapping for configuring models and circuits.

SettingsValue TypeAlias

Any value that can be stored in a settings dictionary.

WL_C float

C-band center wavelength (μm).

WL_C_MAX float

C-band maximum wavelength (μm).

WL_C_MIN float

C-band minimum wavelength (μm).

WL_E float

E-band center wavelength (μm).

WL_E_MAX float

E-band maximum wavelength (μm).

WL_E_MIN float

E-band minimum wavelength (μm).

WL_L float

L-band center wavelength (μm).

WL_L_MAX float

L-band maximum wavelength (μm).

WL_L_MIN float

L-band minimum wavelength (μm).

WL_O float

O-band center wavelength (μm).

WL_O_MAX float

O-band maximum wavelength (μm).

WL_O_MIN float

O-band minimum wavelength (μm).

WL_S float

S-band center wavelength (μm).

WL_S_MAX float

S-band maximum wavelength (μm).

WL_S_MIN float

S-band minimum wavelength (μm).

AnyNetlist module-attribute

Any valid netlist format: flat, recursive, or simplified dictionary.

ArrayLike module-attribute

ArrayLike: TypeAlias = Array | ndarray

Anything that can turn into an array with ndim>=1.

Backend module-attribute

Backend: TypeAlias = Annotated[
    Literal["filipsson_gunnar", "additive", "forward", "klu"], val(val_backend)
]

Available SAX backend algorithms for circuit simulation.

BackendLike module-attribute

BackendLike: TypeAlias = Backend | Literal['default', 'fg']

Backend specification allowing 'default' to use the system default backend.

Bool module-attribute

Bool: TypeAlias = Annotated[bool | bool_, val(val_bool, strict=False)]

Any boolean value (Python bool or NumPy boolean).

BoolArray module-attribute

BoolArray: TypeAlias = Annotated[Array, bool_, val(val_bool_array, strict=False)]

N-dimensional boolean array (JAX Array with boolean dtype).

BoolArrayLike module-attribute

BoolArrayLike: TypeAlias = Annotated[
    ArrayLike | BoolLike, bool_, val(val_bool_array, cast=False)
]

Anything that can be cast into an N-dimensional Bool array without loss of data.

BoolLike module-attribute

BoolLike: TypeAlias = Annotated[bool | bool_, val(val_bool, cast=False)]

Anything that can be cast into a Bool without loss of data.

C_M_S module-attribute

C_M_S: float = 299792458.0

Speed of light in vacuum (m/s).

C_UM_S module-attribute

C_UM_S: float = 1000000.0 * C_M_S

Speed of light in vacuum (μm/s).
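C_UM_S expresses the speed of light in μm/s, so it can be divided directly by the μm wavelengths used throughout SAX. A minimal sketch of that conversion (f = c / λ) follows; the helper names are hypothetical and not part of the SAX API.

```python
# Values copied from the constants documented above.
C_M_S: float = 299792458.0  # speed of light in vacuum (m/s)
C_UM_S: float = 1e6 * C_M_S  # 1 m = 1e6 μm, so this is c in μm/s


def wl_um_to_freq_hz(wl_um: float) -> float:
    """Hypothetical helper: optical frequency (Hz) for a vacuum wavelength in μm."""
    return C_UM_S / wl_um


def freq_hz_to_wl_um(freq_hz: float) -> float:
    """Hypothetical helper: vacuum wavelength (μm) for an optical frequency in Hz."""
    return C_UM_S / freq_hz


f = wl_um_to_freq_hz(1.55)  # about 193.4 THz
print(f"{f / 1e12:.1f} THz")
```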

Complex module-attribute

Complex: TypeAlias = Annotated[
    complex | complexfloating, val(val_complex, strict=False)
]

Any complex number (Python complex or NumPy complex floating).

ComplexArray module-attribute

ComplexArray: TypeAlias = Annotated[
    Array, complexfloating, val(val_complex_array, strict=False)
]

N-dimensional complex array (JAX Array with complex dtype).

ComplexArray1D module-attribute

ComplexArray1D: TypeAlias = Annotated[
    ArrayLike, complexfloating, 1, val(val_complex_array_1d, strict=False)
]

1-dimensional complex array (JAX Array with complex dtype).

ComplexArray1DLike module-attribute

ComplexArray1DLike: TypeAlias = Annotated[
    ComplexArrayLike, inexact, 1, val(val_complex_array_1d, cast=False)
]

Anything that can be cast into a 1-dimensional complex array without loss of data.

ComplexArrayLike module-attribute

ComplexArrayLike: TypeAlias = Annotated[
    ArrayLike | ComplexLike, inexact, val(val_complex_array, cast=False)
]

Anything that can be cast into an N-dim Complex array without loss of data.

ComplexLike module-attribute

ComplexLike: TypeAlias = Annotated[
    FloatLike | complex | inexact, val(val_complex, cast=False)
]

Anything that can be cast into a Complex without loss of data.

Component module-attribute

Component: TypeAlias = Annotated[str, val(val_name, name='Component')]

The name of a component model (must be a valid Python identifier).

Connections module-attribute

A mapping defining point-to-point connections between instance ports.

DEFAULT_MODE module-attribute

DEFAULT_MODE: str = 'TE'

Default optical mode.

DEFAULT_MODES module-attribute

DEFAULT_MODES: tuple[str, ...] = ('TE', 'TM')

Default multimode configuration.

DEFAULT_WL_STEP module-attribute

DEFAULT_WL_STEP: float = 0.0001

Default wavelength step for array generation (μm).

EPS module-attribute

EPS: float = 1e-12

Small numerical epsilon used to decide which S-parameter values can be treated as zero.
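As an illustration of how such a threshold is typically applied, the sketch below prunes scalar S-parameters whose magnitude does not exceed EPS from an SDict-style dictionary. The helper drop_negligible is hypothetical, and whether SAX itself compares with > or >= is an assumption here.

```python
EPS: float = 1e-12  # value documented above


def drop_negligible(
    sdict: dict[tuple[str, str], complex], eps: float = EPS
) -> dict[tuple[str, str], complex]:
    """Hypothetical helper: keep only entries whose magnitude is greater than ``eps``."""
    return {ports: s for ports, s in sdict.items() if abs(s) > eps}


s = {
    ("in0", "out0"): 0.7 + 0.1j,
    ("in0", "in0"): 1e-15,  # numerically zero reflection
    ("out0", "in0"): 0.7 + 0.1j,
}
pruned = drop_negligible(s)
print(sorted(pruned))
```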

Float module-attribute

Float: TypeAlias = Annotated[float | floating, val(val_float, strict=False)]

Any floating-point number (Python float or NumPy floating).

FloatArray module-attribute

FloatArray: TypeAlias = Annotated[Array, floating, val(val_float_array, strict=False)]

N-dimensional floating-point array (JAX Array with float dtype).

FloatArray1D module-attribute

FloatArray1D: TypeAlias = Annotated[
    ArrayLike, floating, 1, val(val_float_array_1d, strict=False)
]

1-dimensional floating-point array (JAX Array with float dtype).

FloatArray1DLike module-attribute

FloatArray1DLike: TypeAlias = Annotated[
    FloatArrayLike, floating, 1, val(val_float_array_1d, cast=False)
]

Anything that can be cast into a 1-dimensional float array without loss of data.

FloatArray2D module-attribute

FloatArray2D: TypeAlias = Annotated[
    ArrayLike, floating, 2, val(val_float_array_2d, strict=False)
]

2-dimensional floating-point array (JAX Array with float dtype).

FloatArray2DLike module-attribute

FloatArray2DLike: TypeAlias = Annotated[
    FloatArrayLike, floating, 2, val(val_float_array_2d, cast=False)
]

Anything that can be cast into a 2-dimensional float array without loss of data.

FloatArrayLike module-attribute

FloatArrayLike: TypeAlias = Annotated[
    ArrayLike | FloatLike, floating, val(val_float_array, cast=False)
]

Anything that can be cast into an N-dimensional Float array without loss of data.

FloatLike module-attribute

FloatLike: TypeAlias = Annotated[IntLike | float | floating, val(val_float, cast=False)]

Anything that can be cast into a Float without loss of data.

Instance module-attribute

Instance: TypeAlias = Annotated[_Instance, val(val_instance)]

A component instantiation in a netlist, with optional settings and array configuration.

InstanceName module-attribute

InstanceName: TypeAlias = Annotated[str, val(val_instance_name)]

An instance name allowing dots and angle brackets for hierarchical naming.

InstancePort module-attribute

InstancePort: TypeAlias = Annotated[str, val(val_instance_port)]

An instance port reference in the format 'instance_name,port_name'.

Instances module-attribute

A mapping from instance names to their definitions.

Int module-attribute

Int: TypeAlias = Annotated[int | signedinteger, val(val_int, strict=False)]

Any signed integer (Python int or NumPy signed integer).

IntArray module-attribute

IntArray: TypeAlias = Annotated[Array, signedinteger, val(val_int_array, strict=False)]

N-dimensional signed integer array (JAX Array with integer dtype).

IntArray1D module-attribute

IntArray1D: TypeAlias = Annotated[
    ArrayLike, signedinteger, 1, val(val_int_array_1d, strict=False)
]

1-dimensional signed integer array (JAX Array with integer dtype).

IntArray1DLike module-attribute

IntArray1DLike: TypeAlias = Annotated[
    IntArrayLike, integer, 1, val(val_int_array_1d, cast=False)
]

Anything that can be cast into a 1-dimensional integer array without loss of data.

IntArrayLike module-attribute

IntArrayLike: TypeAlias = Annotated[
    ArrayLike | IntLike, integer, val(val_int_array, cast=False)
]

Anything that can be cast into an N-dimensional Int array without loss of data.

IntLike module-attribute

IntLike: TypeAlias = Annotated[int | integer, val(val_int, cast=False)]

Anything that can be cast into an Int without loss of data.

Mode module-attribute

Mode: TypeAlias = Annotated[str, val(val_mode)]

A mode identifier string (e.g., '0', 'TE0', 'TM1').

Model module-attribute

A SAX model function that works with either single-mode or multi-mode circuits.

ModelFactory module-attribute

A SAX model factory function that works with single-mode or multi-mode circuits.

ModelFactoryMM module-attribute

ModelFactoryMM: TypeAlias = Annotated[
    SDictModelFactoryMM | SDenseModelFactoryMM | SCooModelFactoryMM,
    val(val_model_factory),
]

A keyword-only function that produces any multi-mode model.

ModelFactorySM module-attribute

ModelFactorySM: TypeAlias = Annotated[
    SDictModelFactorySM | SDenseModelFactorySM | SCooModelFactorySM,
    val(val_model_factory),
]

A keyword-only function that produces any single-mode model.

ModelMM module-attribute

ModelMM: TypeAlias = Annotated[
    SDenseModelMM | SCooModelMM | SDictModelMM, val(val_model)
]

A keyword-only function that produces any multi-mode S-matrix type.

ModelSM module-attribute

ModelSM: TypeAlias = Annotated[
    SDictModelSM | SDenseModelSM | SCooModelSM, val(val_model)
]

A keyword-only function that produces any single-mode S-matrix type.

Models module-attribute

A collection of model functions, supporting both single-mode and multi-mode.

ModelsMM module-attribute

ModelsMM: TypeAlias = dict[Name, ModelMM]

A mapping from model names to multi-mode model functions.

ModelsSM module-attribute

ModelsSM: TypeAlias = dict[Name, ModelSM]

A mapping from model names to single-mode model functions.

Name module-attribute

Name: TypeAlias = Annotated[str, val(val_name)]

A valid Python identifier - contains only letters, numbers, and underscores.

Nets module-attribute

Nets: TypeAlias = list[Net]

A list of logical connections between ports.

Placements module-attribute

A mapping from instance names to their physical placements.

Port module-attribute

Port: TypeAlias = Annotated[str, val(val_port)]

A single-mode port name - must be a valid Python identifier.

PortCombination module-attribute

A pair of port names, either single-mode or multi-mode format.

PortCombinationMM module-attribute

PortCombinationMM: TypeAlias = tuple[PortMode, PortMode]

A pair of multi-mode port-mode names representing an S-parameter.

PortCombinationSM module-attribute

PortCombinationSM: TypeAlias = tuple[Port, Port]

A pair of single-mode port names representing an S-parameter.

PortMap module-attribute

A mapping from port names to matrix indices, supporting both mode types.

PortMapMM module-attribute

PortMapMM: TypeAlias = dict[PortMode, int]

A mapping from multi-mode port-mode names to their matrix indices.

PortMapSM module-attribute

PortMapSM: TypeAlias = dict[Port, int]

A mapping from single-mode port names to their matrix indices.

PortMode module-attribute

PortMode: TypeAlias = Annotated[str, val(val_port_mode)]

A port-mode specification in the format 'port_name@mode_name'.
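A port-mode string is a port name and a mode name joined by a single '@'. The hypothetical helper below only illustrates that format; SAX's own get_mode function is the library route for extracting the mode.

```python
def split_port_mode(port_mode: str) -> tuple[str, str]:
    """Hypothetical helper: split 'port_name@mode_name' into (port_name, mode_name)."""
    port, sep, mode = port_mode.partition("@")
    # Require exactly one '@' with non-empty text on both sides.
    if not sep or not port or not mode or "@" in mode:
        raise ValueError(f"expected 'port_name@mode_name', got {port_mode!r}")
    return port, mode


print(split_port_mode("in0@TE"))  # ('in0', 'TE')
```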

Ports module-attribute

Ports: TypeAlias = Annotated[dict[Port, InstancePort], val(val_ports)]

A mapping from external circuit ports to internal instance ports.

RecursiveNetlist module-attribute

RecursiveNetlist: TypeAlias = dict[Name, Netlist]

A hierarchical netlist containing multiple named circuits.

SCoo module-attribute

A sparse COO format S-matrix, supporting both single-mode and multi-mode.

SCooMM module-attribute

A sparse S-matrix in COO format (recommended for internal library use only).

An SCoo is a sparse-matrix representation of an S-matrix consisting of three arrays and a port map. The three arrays hold the input port indices [int], the output port indices [int], and the S-matrix values [ComplexFloat]. The port map maps a port name [str] to a port index [int].

Only these four elements (the three arrays followed by the port map), together and in this specific order, are considered a valid SCoo representation.

Examples:

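Creating an SCooMM:

import jax.numpy as jnp
import sax

Si = jnp.arange(3, dtype=int)  # input port indices
Sj = jnp.array([0, 1, 0], dtype=int)  # output port indices
Sx = jnp.array([3.0, 4.0, 1.0])  # S-parameter values
port_map = {"in0@TE": 0, "in1@TE": 2, "out0@TE": 1}  # multi-mode 'port@mode' names
scoo: sax.SCoo = (Si, Sj, Sx, port_map)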
Note

This representation is only recommended for internal library use. Please don't write user-facing code using this representation.
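To make the layout concrete, the sketch below scatters the COO triplets of a multi-mode SCoo into a dense matrix. It uses NumPy instead of JAX so it runs without SAX installed, uses port@mode names as PortMapMM expects, and handles only an unbatched Sx. The helper scoo_to_dense is hypothetical; SAX's sdense function is the supported conversion.

```python
import numpy as np

# A multi-mode SCoo: three index/value arrays plus a 'port@mode' port map.
Si = np.arange(3, dtype=int)  # input port indices
Sj = np.array([0, 1, 0], dtype=int)  # output port indices
Sx = np.array([3.0, 4.0, 1.0], dtype=complex)  # S-parameter values
port_map = {"in0@TE": 0, "in1@TE": 2, "out0@TE": 1}


def scoo_to_dense(Si, Sj, Sx, port_map):
    """Hypothetical helper: place each Sx[k] at row Si[k], column Sj[k]."""
    n = len(port_map)
    S = np.zeros((n, n), dtype=complex)
    S[Si, Sj] = Sx  # entries not listed in the COO arrays stay zero
    return S


S = scoo_to_dense(Si, Sj, Sx, port_map)
print(S[port_map["in1@TE"], port_map["in0@TE"]])
```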

SCooModel module-attribute

A model function that produces SCoo S-matrices in either mode format.

SCooModelFactory module-attribute

A model factory that produces SCoo models in either mode format.

SCooModelFactoryMM module-attribute

SCooModelFactoryMM: TypeAlias = Annotated[
    Callable[..., SCooModelMM], val(val_model_factory)
]

A keyword-only function that produces a multi-mode SCoo model.

SCooModelFactorySM module-attribute

SCooModelFactorySM: TypeAlias = Annotated[
    Callable[..., SCooModelSM], val(val_model_factory)
]

A keyword-only function that produces a single-mode SCoo model.

SCooModelMM module-attribute

SCooModelMM: TypeAlias = Annotated[Callable[..., SCooMM], val(val_model)]

A keyword-only function that produces a multi-mode SCoo S-matrix.

SCooModelSM module-attribute

SCooModelSM: TypeAlias = Annotated[Callable[..., SCooSM], val(val_model)]

A keyword-only function that produces a single-mode SCoo S-matrix.

SCooSM module-attribute

A sparse S-matrix in COO format (recommended for internal library use only).

An SCoo is a sparse-matrix representation of an S-matrix consisting of three arrays and a port map. The three arrays hold the input port indices [int], the output port indices [int], and the S-matrix values [ComplexFloat]. The port map maps a port name [str] to a port index [int].

Only these four elements (the three arrays followed by the port map), together and in this specific order, are considered a valid SCoo representation.

Examples:

Creating an SCooSM:

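import jax.numpy as jnp
import sax

Si = jnp.arange(3, dtype=int)  # input port indices
Sj = jnp.array([0, 1, 0], dtype=int)  # output port indices
Sx = jnp.array([3.0, 4.0, 1.0])  # S-parameter values
port_map = {"in0": 0, "in1": 2, "out0": 1}  # single-mode port names
scoo: sax.SCoo = (Si, Sj, Sx, port_map)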
Note

This representation is only recommended for internal library use. Please don't write user-facing code using this representation.
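Reading an SCoo back in terms of port names only requires inverting the port map. The sketch below (NumPy, unbatched Sx, hypothetical helper name) turns the SCooSM example into an SDict-style dictionary; SAX's sdict function is the supported conversion.

```python
import numpy as np

Si = np.arange(3, dtype=int)
Sj = np.array([0, 1, 0], dtype=int)
Sx = np.array([3.0, 4.0, 1.0])
port_map = {"in0": 0, "in1": 2, "out0": 1}


def scoo_to_sdict(Si, Sj, Sx, port_map):
    """Hypothetical helper: map each (Si[k], Sj[k]) index pair back to port names."""
    index_to_port = {index: name for name, index in port_map.items()}
    return {
        # float() only because this toy example is real-valued
        (index_to_port[int(i)], index_to_port[int(j)]): float(x)
        for i, j, x in zip(Si, Sj, Sx)
    }


sdict = scoo_to_sdict(Si, Sj, Sx, port_map)
print(sdict)  # {('in0', 'in0'): 3.0, ('out0', 'out0'): 4.0, ('in1', 'in0'): 1.0}
```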

SDense module-attribute

A dense S-matrix representation, supporting both single-mode and multi-mode.

SDenseMM module-attribute

A dense S-matrix representation.

An S-matrix stored either as a 2D array or as a batched (N+2)-D array, together with a port map. For an (N+2)-D array, the last two dimensions are the S-matrix dimensions and the leading N dimensions are batch dimensions.

Examples:

Creating an SDenseMM:

import jax.numpy as jnp

Sd = jnp.arange(9, dtype=float).reshape(3, 3)  # 3x3 S-matrix
port_map = {"in0@TE": 0, "in1@TE": 2, "out0@TE": 1}  # multi-mode 'port@mode' names
sdense = Sd, port_map
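With a batched SDense, each port pair indexes the last two axes while the leading axes run over the batch (for example, wavelength). Below is a NumPy sketch with port@mode names; the transmission values are toy numbers chosen only to fill the array, not a physical model.

```python
import numpy as np

port_map = {"in0@TE": 0, "in1@TE": 2, "out0@TE": 1}
n = len(port_map)
wl = np.linspace(1.53, 1.565, 5)  # five wavelengths -> batch shape (5,)

# (N+2)-D batched S-matrix: batch axes first, the two S-matrix axes last.
S = np.zeros((wl.size, n, n), dtype=complex)
t = np.exp(2j * np.pi / wl)  # toy unit-magnitude transmission, one value per wavelength
i, j = port_map["in0@TE"], port_map["out0@TE"]
S[:, i, j] = t
S[:, j, i] = t  # reciprocal toy device

sdense = (S, port_map)
s_in_out = S[..., i, j]  # S-parameter for (in0@TE, out0@TE) across the batch
print(S.shape, s_in_out.shape)  # (5, 3, 3) (5,)
```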

SDenseModel module-attribute

A model function that produces SDense S-matrices in either mode format.

SDenseModelFactory module-attribute

A model factory that produces SDense models in either mode format.

SDenseModelFactoryMM module-attribute

SDenseModelFactoryMM: TypeAlias = Annotated[
    Callable[..., SDenseModelMM], val(val_model_factory)
]

A keyword-only function that produces a multi-mode SDense model.

SDenseModelFactorySM module-attribute

SDenseModelFactorySM: TypeAlias = Annotated[
    Callable[..., SDenseModelSM], val(val_model_factory)
]

A keyword-only function that produces a single-mode SDense model.

SDenseModelMM module-attribute

SDenseModelMM: TypeAlias = Annotated[Callable[..., SDenseMM], val(val_model)]

A keyword-only function that produces a multi-mode SDense S-matrix.

SDenseModelSM module-attribute

SDenseModelSM: TypeAlias = Annotated[Callable[..., SDenseSM], val(val_model)]

A keyword-only function that produces a single-mode SDense S-matrix.

SDenseSM module-attribute

A dense S-matrix representation.

An S-matrix stored either as a 2D array or as a batched (N+2)-D array, together with a port map. For an (N+2)-D array, the last two dimensions are the S-matrix dimensions and the leading N dimensions are batch dimensions.

Examples:

Creating an SDenseSM:

import jax.numpy as jnp

Sd = jnp.arange(9, dtype=float).reshape(3, 3)  # 3x3 S-matrix
port_map = {"in0": 0, "in1": 2, "out0": 1}  # single-mode port names
sdense = Sd, port_map
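Conversely, a dense S-matrix can be read as a port-pair dictionary by looking up each port's index in the port map. This NumPy sketch (hypothetical helper; SAX's sdict function is the supported conversion) keeps every port pair, including zero entries, and indexes with S[..., i, j] so it also works on batched arrays.

```python
import numpy as np

Sd = np.arange(9, dtype=float).reshape(3, 3)
port_map = {"in0": 0, "in1": 2, "out0": 1}


def sdense_to_sdict(S, port_map):
    """Hypothetical helper: one dictionary entry per (port, port) pair."""
    return {
        (p1, p2): S[..., i, j]
        for p1, i in port_map.items()
        for p2, j in port_map.items()
    }


sdict = sdense_to_sdict(Sd, port_map)
print(sdict[("in1", "out0")])  # 7.0, i.e. Sd[2, 1]
```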

SDict module-attribute

A dictionary-based sparse S-matrix, supporting both single-mode and multi-mode.

SDictMM module-attribute

A sparse dictionary-based S-matrix representation.

A mapping from a port combination to an S-parameter or an array of S-parameters.

Examples:

Creating an SDictMM:

import sax

sdict: sax.SDict = {
    ("in0@TE", "out0@TE"): 3.0,  # multi-mode 'port@mode' names
}
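An SDict only lists the port pairs it needs, so turning it into a dense matrix means first choosing a port ordering. The sketch below sorts the port@mode names alphabetically; that ordering is an assumption for illustration and not necessarily what SAX's sdense uses. It also assumes scalar (unbatched) values, and the helper name is hypothetical.

```python
import numpy as np

sdict = {
    ("in0@TE", "out0@TE"): 3.0,
    ("out0@TE", "in0@TE"): 3.0,
}


def sdict_to_dense(sdict):
    """Hypothetical helper: build (S, port_map) from an SDict with scalar values."""
    ports = sorted({port for pair in sdict for port in pair})
    port_map = {port: index for index, port in enumerate(ports)}
    S = np.zeros((len(ports), len(ports)), dtype=complex)
    for (p1, p2), value in sdict.items():
        S[port_map[p1], port_map[p2]] = value  # missing pairs stay zero
    return S, port_map


S, port_map = sdict_to_dense(sdict)
print(port_map)  # {'in0@TE': 0, 'out0@TE': 1}
```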

SDictModel module-attribute

A model function that produces SDict S-matrices in either mode format.

SDictModelFactory module-attribute

A model factory that produces SDict models in either mode format.

SDictModelFactoryMM module-attribute

SDictModelFactoryMM: TypeAlias = Annotated[
    Callable[..., SDictModelMM], val(val_model_factory)
]

A keyword-only function that produces a multi-mode SDict model.

SDictModelFactorySM module-attribute

SDictModelFactorySM: TypeAlias = Annotated[
    Callable[..., SDictModelSM], val(val_model_factory)
]

A keyword-only function that produces a single-mode SDict model.

SDictModelMM module-attribute

SDictModelMM: TypeAlias = Annotated[Callable[..., SDictMM], val(val_model)]

A keyword-only function that produces a multi-mode SDict S-matrix.

SDictModelSM module-attribute

SDictModelSM: TypeAlias = Annotated[Callable[..., SDictSM], val(val_model)]

A keyword-only function that produces a single-mode SDict S-matrix.

SDictSM module-attribute

A sparse dictionary-based S-matrix representation.

A mapping from a port combination to an S-parameter or an array of S-parameters.

Examples:

Creating an SDictSM:

import sax

sdict: sax.SDict = {
    ("in0", "out0"): 3.0,
}
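The reciprocal function listed above is described as making an SDict reciprocal by ensuring S[i,j] = S[j,i]. A minimal pure-Python sketch of that idea fills in each missing reverse pair from its forward counterpart. How the library resolves pairs present in both directions with different values is not covered here; this hypothetical helper simply keeps existing entries.

```python
def make_reciprocal(
    sdict: dict[tuple[str, str], complex],
) -> dict[tuple[str, str], complex]:
    """Hypothetical helper: add (p2, p1) wherever only (p1, p2) is present."""
    result = dict(sdict)
    for (p1, p2), value in sdict.items():
        result.setdefault((p2, p1), value)  # never overwrite an existing entry
    return result


sdict = {("in0", "out0"): 3.0}
print(make_reciprocal(sdict))  # {('in0', 'out0'): 3.0, ('out0', 'in0'): 3.0}
```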

SType module-attribute

Any S-matrix type (SDict, SDense, SCoo) in single-mode or multi-mode format.

STypeMM module-attribute

Any multi-mode S-matrix type (SDict, SDense, or SCoo).

STypeSM module-attribute

Any single-mode S-matrix type (SDict, SDense, or SCoo).

Settings module-attribute

Settings: TypeAlias = Annotated[dict[str, SettingsValue], val(val_settings)]

A (possibly nested) settings mapping for configuring models and circuits.

Settings provide a hierarchical way to configure SAX models and circuits. Top-level keys often correspond to global parameters or instance names, while nested dictionaries contain model-specific parameters.

Examples:

Define settings for a model or circuit:

import sax

mzi_settings: sax.Settings = {
    "wl": 1.5,  # global wavelength setting
    "lft": {"coupling": 0.5},  # settings for the left coupler
    "top": {"neff": 3.4},      # settings for the top waveguide
    "rgt": {"coupling": 0.3},  # settings for the right coupler
}

array_settings: sax.Settings = {
    "wavelengths": [1.50, 1.51, 1.52, 1.53, 1.54, 1.55],
    "temperature": 25.0,
    "component_settings": {
        "loss_db_per_cm": 0.1,
        "group_index": 4.2
    }
}
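Nested settings are often easier to inspect or sweep once flattened into single-level keys. The sketch below does that recursively; the helper and its ',' separator are assumptions for illustration, and SAX's own flatten_dict may use a different separator or signature.

```python
from typing import Any


def flatten_settings(
    settings: dict[str, Any], sep: str = ",", prefix: str = ""
) -> dict[str, Any]:
    """Hypothetical helper: join nested keys with ``sep``; leaf values are kept as-is."""
    flat: dict[str, Any] = {}
    for key, value in settings.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_settings(value, sep=sep, prefix=full_key))
        else:
            flat[full_key] = value
    return flat


mzi_settings = {
    "wl": 1.5,
    "lft": {"coupling": 0.5},
    "top": {"neff": 3.4},
    "rgt": {"coupling": 0.3},
}
print(flatten_settings(mzi_settings))
# {'wl': 1.5, 'lft,coupling': 0.5, 'top,neff': 3.4, 'rgt,coupling': 0.3}
```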

SettingsValue module-attribute

SettingsValue: TypeAlias = Annotated[Any, val(val_settings_value)]

Any value that can be stored in a settings dictionary.

WL_C module-attribute

WL_C: float = 0.5 * (WL_C_MIN + WL_C_MAX)

C-band center wavelength (μm).

WL_C_MAX module-attribute

WL_C_MAX: float = 1.565

C-band maximum wavelength (μm).

WL_C_MIN module-attribute

WL_C_MIN: float = 1.53

C-band minimum wavelength (μm).

WL_E module-attribute

WL_E: float = 0.5 * (WL_E_MIN + WL_E_MAX)

E-band center wavelength (μm).

WL_E_MAX module-attribute

WL_E_MAX: float = 1.46

E-band maximum wavelength (μm).

WL_E_MIN module-attribute

WL_E_MIN: float = 1.36

E-band minimum wavelength (μm).

WL_L module-attribute

WL_L: float = 0.5 * (WL_L_MIN + WL_L_MAX)

L-band center wavelength (μm).

WL_L_MAX module-attribute

WL_L_MAX: float = 1.625

L-band maximum wavelength (μm).

WL_L_MIN module-attribute

WL_L_MIN: float = 1.565

L-band minimum wavelength (μm).

WL_O module-attribute

WL_O: float = 0.5 * (WL_O_MIN + WL_O_MAX)

O-band center wavelength (μm).

WL_O_MAX module-attribute

WL_O_MAX: float = 1.36

O-band maximum wavelength (μm).

WL_O_MIN module-attribute

WL_O_MIN: float = 1.26

O-band minimum wavelength (μm).

WL_S module-attribute

WL_S: float = 0.5 * (WL_S_MIN + WL_S_MAX)

S-band center wavelength (μm).

WL_S_MAX module-attribute

WL_S_MAX: float = 1.53

S-band maximum wavelength (μm).

WL_S_MIN module-attribute

WL_S_MIN: float = 1.46

S-band minimum wavelength (μm).
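Each band-center constant is the midpoint of its band edges, so the sum has to be taken before halving: written without parentheses, `0.5 * WL_C_MIN + WL_C_MAX` would evaluate to about 2.33 μm, far outside the C-band. A quick standalone check using the edge values listed above (copied here rather than imported from `sax`):

```python
# Band edges in μm, copied from the constants documented above.
BANDS = {
    "O": (1.26, 1.36),
    "E": (1.36, 1.46),
    "S": (1.46, 1.53),
    "C": (1.53, 1.565),
    "L": (1.565, 1.625),
}

# Center wavelength = midpoint of the band edges: 0.5 * (min + max).
centers = {band: 0.5 * (wl_min + wl_max) for band, (wl_min, wl_max) in BANDS.items()}

# The bands are contiguous: each maximum is the next band's minimum.
ordered = list(BANDS.values())
contiguous = all(lo_band[1] == hi_band[0] for lo_band, hi_band in zip(ordered, ordered[1:]))

for band, wl in centers.items():
    print(f"{band}-band center: {wl:.4f} μm")
print("contiguous:", contiguous)
```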

CircuitInfo

Bases: NamedTuple

Information about a SAX circuit function.

This class contains metadata about the circuit structure, models used, and backend selected during circuit compilation.

Attributes:

Name Type Description
dag DiGraph

The directed acyclic graph representing the circuit topology.

models dict[Name, Model]

Dictionary mapping instance names to their model functions.

backend Backend

The backend algorithm used for circuit simulation.

Examples:

Usage of CircuitInfo after creating a circuit function:

circuit_func, info = sax.circuit(netlist, models, backend="klu")

print(f"Circuit uses {len(info.models)} models")
print(f"Backend: {info.backend}")
print(f"Circuit has {len(info.dag.nodes)} instances")

backend instance-attribute

backend: Backend

The backend algorithm used for circuit simulation.

dag instance-attribute

dag: DiGraph

The circuit topology as a directed acyclic graph.

models instance-attribute

models: dict[Name, Model]

Mapping from instance names to their model functions.

IOLike

Bases: Protocol

Protocol for file-like objects that can be read from.

This protocol defines the interface for objects that can be used for reading text data, such as file objects or StringIO objects.

Methods:

Name Description
read

Read and return up to size characters.

read

read(size: int | None = -1, /) -> str

Read and return up to size characters.

Parameters:

Name Type Description Default
size int | None

Maximum number of characters to read. If None or -1, read all.

-1

Returns:

Type Description
str

The read characters as a string.

Source code in src/sax/saxtypes/core.py
def read(self, size: int | None = -1, /) -> str:
    """Read and return up to size characters.

    Args:
        size: Maximum number of characters to read. If None or -1, read all.

    Returns:
        The read characters as a string.
    """
    ...

Net

Bases: TypedDict

A logical connection between two ports.

Represents a point-to-point connection with optional metadata.

Attributes:

Name Type Description
p1 str

First port in the connection.

p2 str

Second port in the connection.

settings NotRequired[dict]

Optional connection settings.

name NotRequired[str | None]

Optional connection name.

Netlist

Bases: TypedDict

A complete netlist definition for an optical circuit.

Contains all information needed to define a circuit: instances, connections, external ports, and optional placement/settings.

Attributes:

Name Type Description
instances Instances

The component instances in the circuit.

connections NotRequired[Connections]

Point-to-point connections between instances.

ports Ports

Mapping of external ports to internal instance ports.

nets NotRequired[Nets]

Alternative connection specification: a list of Net entries instead of a mapping.

placements NotRequired[Placements]

Physical placement information for instances.

settings NotRequired[Settings]

Global circuit settings.
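A minimal sketch of a netlist for a Mach-Zehnder interferometer, written as a plain dictionary. The instance names reuse `lft`/`top`/`rgt` from the Settings example; the component names (`coupler`, `waveguide`), port names (`in0`, `out0`, ...), settings and placement coordinates are hypothetical and would have to match the models you register. Instance ports are referenced as `"instance,port"` strings, the same format used in the backend examples, and the links can equally be written as a list of `Net` entries with `p1`/`p2` keys:

```python
# Hypothetical MZI netlist; component names, port names, settings and
# coordinates are placeholders that must match the models you register.
netlist = {
    "instances": {
        "lft": {"component": "coupler", "settings": {"coupling": 0.5}},
        "top": {"component": "waveguide", "settings": {"length": 20.0}},
        "btm": {"component": "waveguide", "settings": {"length": 10.0}},
        "rgt": {"component": "coupler", "settings": {"coupling": 0.5}},
    },
    # internal point-to-point links: "instance,port" -> "instance,port"
    "connections": {
        "lft,out0": "top,in0",
        "lft,out1": "btm,in0",
        "top,out0": "rgt,in0",
        "btm,out0": "rgt,in1",
    },
    # external port name -> "instance,port"
    "ports": {
        "in0": "lft,in0",
        "in1": "lft,in1",
        "out0": "rgt,out0",
        "out1": "rgt,out1",
    },
    # optional Placement entries (layout information only)
    "placements": {
        "lft": {"x": 0.0, "y": 0.0},
        "rgt": {"x": 100.0, "y": 0.0, "rotation": 0.0},
    },
}

# The same links written as Net entries, i.e. the list form used by "nets".
nets = [{"p1": p1, "p2": p2} for p1, p2 in netlist["connections"].items()]

# Sanity check: every "instance,port" reference names a defined instance.
refs = [
    *netlist["connections"],
    *netlist["connections"].values(),
    *netlist["ports"].values(),
]
undefined = {ref.split(",")[0] for ref in refs} - set(netlist["instances"])
print(nets[0])
print(undefined or "all references resolve")
```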

Normalization

Bases: NamedTuple

Normalization parameters for an array.

Contains the mean and standard deviation values needed to normalize and denormalize arrays. Typically used for machine learning preprocessing.

Attributes:

Name Type Description
mean ComplexArray

Mean values for normalization.

std ComplexArray

Standard deviation values for normalization.
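A minimal sketch of how mean and standard-deviation parameters are typically applied, assuming the usual z-score convention `(x - mean) / std` and its inverse `x * std + mean`; SAX's own `denormalize` may differ in details such as the axis over which statistics are taken. The class below is a local stand-in with the same two fields, `fit_normalization` and `normalize` are hypothetical helpers, and NumPy stands in for `jax.numpy`:

```python
from typing import NamedTuple

import numpy as np


class Normalization(NamedTuple):
    """Local stand-in mirroring the two fields documented above."""

    mean: np.ndarray
    std: np.ndarray


def fit_normalization(x: np.ndarray) -> Normalization:
    # Statistics over the whole array; a real pipeline might reduce along an axis.
    return Normalization(mean=np.asarray(np.mean(x)), std=np.asarray(np.std(x)))


def normalize(x: np.ndarray, norm: Normalization) -> np.ndarray:
    return (x - norm.mean) / norm.std


def denormalize(x: np.ndarray, norm: Normalization) -> np.ndarray:
    return x * norm.std + norm.mean


wl = np.linspace(1.5, 1.6, 11)
norm = fit_normalization(wl)
z = normalize(wl, norm)
print(z.mean(), z.std())  # approximately 0 and 1
print(np.allclose(denormalize(z, norm), wl))  # round trip recovers the input
```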

Placement

Bases: TypedDict

Physical placement information for an instance.

Defines the position, orientation, and constraints for placing an instance in physical layout coordinates.

Attributes:

Name Type Description
x str | float

X coordinate position.

y str | float

Y coordinate position.

dx NotRequired[str | float]

Optional X offset.

dy NotRequired[str | float]

Optional Y offset.

rotation NotRequired[float]

Optional rotation angle in degrees.

mirror NotRequired[bool]

Optional mirroring flag.

xmin NotRequired[str | float | None]

Optional minimum X constraint.

xmax NotRequired[str | float | None]

Optional maximum X constraint.

ymin NotRequired[str | float | None]

Optional minimum Y constraint.

ymax NotRequired[str | float | None]

Optional maximum Y constraint.

port NotRequired[str | _PortPlacement | None]

Optional port anchor specification.

PortNamer

PortNamer(
    num_inputs: int, num_outputs: int, strategy: PortNamingStrategy | None = None
)

Port naming class to encapsulate port naming logic.

This class provides a unified interface for generating consistent port names across SAX model functions. It supports different naming strategies and handles the mapping between logical port indices and string names.

The PortNamer automatically applies the current global naming strategy or allows override with a specific strategy. It provides both attribute-based access (e.g., p.in0, p.out1) and index-based access (e.g., p[0], p[1]).

Attributes:

Name Type Description
num_inputs

Number of input ports.

num_outputs

Number of output ports.

num_ports

Total number of ports (inputs + outputs).

strategy

Port naming strategy being used.

Parameters:

Name Type Description Default
num_inputs int

Number of input ports for the device.

required
num_outputs int

Number of output ports for the device.

required
strategy PortNamingStrategy | None

Optional port naming strategy override. If None, uses the current global strategy from get_port_naming_strategy().

None

Examples:

Create a 2x2 device (e.g., directional coupler):

p = PortNamer(2, 2)
print(p.in0, p.in1, p.out0, p.out1)  # inout strategy

Create a 1x2 device (e.g., splitter):

p = PortNamer(1, 2)
print(p.in0, p.out0, p.out1)

Force optical naming:

p = PortNamer(2, 2, strategy="optical")
print(p.o1, p.o2, p.o3, p.o4)

Methods:

Name Description
is_input_port

Check if a port is an input port.

is_input_port_idx

Check if a port index is an input port.

is_output_port

Check if a port is an output port.

is_output_port_idx

Check if a port index is an output port.

Source code in src/sax/ports.py
def __init__(
    self,
    num_inputs: int,
    num_outputs: int,
    strategy: PortNamingStrategy | None = None,
) -> None:
    """Initialize the PortNamer with the number of ports.

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        strategy: Optional port naming strategy override. If None, uses
            the current global strategy from get_port_naming_strategy().

    Examples:
        Create a 2x2 device (e.g., directional coupler):

        ```python
        p = PortNamer(2, 2)
        print(p.in0, p.in1, p.out0, p.out1)  # inout strategy
        ```

        Create a 1x2 device (e.g., splitter):

        ```python
        p = PortNamer(1, 2)
        print(p.in0, p.out0, p.out1)
        ```

        Force optical naming:

        ```python
        p = PortNamer(2, 2, strategy="optical")
        print(p.o1, p.o2, p.o3, p.o4)
        ```
    """
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.num_ports = num_inputs + num_outputs
    self.strategy = (strategy or get_port_naming_strategy()).lower()

is_input_port

is_input_port(port: Name) -> bool

Check if a port is an input port.

Parameters:

Name Type Description Default
port Name

Port name to check.

required

Returns:

Type Description
bool

True if the port is an input port, False otherwise.

Examples:

Check port direction:

p = PortNamer(2, 2)
print(p.is_input_port("in0"))  # True
print(p.is_input_port("out0"))  # False
Source code in src/sax/ports.py
def is_input_port(self, port: sax.Name) -> bool:
    """Check if a port is an input port.

    Args:
        port: Port name to check.

    Returns:
        True if the port is an input port, False otherwise.

    Examples:
        Check port direction:

        ```python
        p = PortNamer(2, 2)
        print(p.is_input_port("in0"))  # True
        print(p.is_input_port("out0"))  # False
        ```
    """
    port = port.lower()
    if port.startswith("in"):
        return True
    if port.startswith("out"):
        return False
    idx = int(re.sub("[a-zA-Z_-]", "", port))
    return idx < self.num_inputs

is_input_port_idx

is_input_port_idx(idx: int) -> bool

Check if a port index is an input port.

Parameters:

Name Type Description Default
idx int

Port index to check.

required

Returns:

Type Description
bool

True if the port index corresponds to an input port, False otherwise.

Examples:

Check port direction by index:

p = PortNamer(2, 2)
print(p.is_input_port_idx(0))  # True (first input)
print(p.is_input_port_idx(2))  # False (first output)
Source code in src/sax/ports.py
def is_input_port_idx(self, idx: int) -> bool:
    """Check if a port index is an input port.

    Args:
        idx: Port index to check.

    Returns:
        True if the port index corresponds to an input port, False otherwise.

    Examples:
        Check port direction by index:

        ```python
        p = PortNamer(2, 2)
        print(p.is_input_port_idx(0))  # True (first input)
        print(p.is_input_port_idx(2))  # False (first output)
        ```
    """
    return self.is_input_port(self[idx])

is_output_port

is_output_port(port: Name) -> bool

Check if a port is an output port.

Parameters:

Name Type Description Default
port Name

Port name to check.

required

Returns:

Type Description
bool

True if the port is an output port, False otherwise.

Examples:

Check port direction:

p = PortNamer(2, 2)
print(p.is_output_port("out0"))  # True
print(p.is_output_port("in0"))  # False
Source code in src/sax/ports.py
def is_output_port(self, port: sax.Name) -> bool:
    """Check if a port is an output port.

    Args:
        port: Port name to check.

    Returns:
        True if the port is an output port, False otherwise.

    Examples:
        Check port direction:

        ```python
        p = PortNamer(2, 2)
        print(p.is_output_port("out0"))  # True
        print(p.is_output_port("in0"))  # False
        ```
    """
    return not self.is_input_port(port)

is_output_port_idx

is_output_port_idx(idx: int) -> bool

Check if a port index is an output port.

Parameters:

Name Type Description Default
idx int

Port index to check.

required

Returns:

Type Description
bool

True if the port index corresponds to an output port, False otherwise.

Examples:

Check port direction by index:

p = PortNamer(2, 2)
print(p.is_output_port_idx(2))  # True (first output)
print(p.is_output_port_idx(0))  # False (first input)
Source code in src/sax/ports.py
def is_output_port_idx(self, idx: int) -> bool:
    """Check if a port index is an output port.

    Args:
        idx: Port index to check.

    Returns:
        True if the port index corresponds to an output port, False otherwise.

    Examples:
        Check port direction by index:

        ```python
        p = PortNamer(2, 2)
        print(p.is_output_port_idx(2))  # True (first output)
        print(p.is_output_port_idx(0))  # False (first input)
        ```
    """
    return self.is_output_port(self[idx])
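The index examples above imply an ordering in which indices `0 .. num_inputs - 1` are inputs and the remaining indices are outputs (for a 2x2 device, index 0 is the first input and index 2 the first output). A minimal sketch of that convention for the `inout` strategy, written as standalone functions; `inout_name` and `is_input_idx` are hypothetical helpers, not part of `PortNamer`, and the exact name format is an assumption based on the `in0`/`out0` examples:

```python
def inout_name(idx: int, num_inputs: int, num_outputs: int) -> str:
    """Hypothetical: map a port index to an 'inout'-style port name."""
    if not 0 <= idx < num_inputs + num_outputs:
        raise IndexError(f"port index {idx} out of range")
    if idx < num_inputs:
        return f"in{idx}"
    # Outputs are numbered from zero again after the inputs.
    return f"out{idx - num_inputs}"


def is_input_idx(idx: int, num_inputs: int) -> bool:
    # Inputs occupy the first num_inputs indices.
    return idx < num_inputs


# 2x2 device, e.g. a directional coupler
print([inout_name(i, 2, 2) for i in range(4)])
print([is_input_idx(i, 2) for i in range(4)])
```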

try_into

Optional type converter utility that returns None on failure.

Use this class with bracket notation to safely attempt type conversion. Returns None instead of raising errors when conversion fails.

Examples:

Validate and cast objects to SAX or python types safely:

import sax.saxtypes as sxt

# Safe conversion attempts
result = sxt.try_into[sxt.Float]("3.14")  # 3.14
result = sxt.try_into[sxt.Float]("invalid")  # None

# Use in conditional expressions
if (value := sxt.try_into[sxt.Complex]("1+2j")) is not None:
    print(f"Converted: {value}")  # (1+2j)

# Handle arrays safely
arr = sxt.try_into[sxt.FloatArray]([[1, 2], [3, 4]])
if arr is not None:
    print(f"Array shape: {arr.shape}")  # (2, 2)
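The bracket syntax can be reproduced with a small object whose `__getitem__` returns a converter that swallows conversion errors. A minimal sketch of the pattern, using plain Python callables such as `float` and `complex` in place of SAX's validating types; `TryInto` is a hypothetical name, and the real `try_into` may catch a different set of exceptions:

```python
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")


class TryInto:
    """Hypothetical sketch: try_into[converter](value) -> result or None."""

    def __getitem__(self, converter: Callable[[Any], T]) -> Callable[[Any], Optional[T]]:
        def attempt(value: Any) -> Optional[T]:
            try:
                return converter(value)
            except (TypeError, ValueError):
                # Conversion failed: report it as None instead of raising.
                return None

        return attempt


try_into = TryInto()

print(try_into[float]("3.14"))     # 3.14
print(try_into[float]("invalid"))  # None
if (value := try_into[complex]("1+2j")) is not None:
    print(f"Converted: {value}")   # Converted: (1+2j)
```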

analyze_circuit_additive

analyze_circuit_additive(
    analyzed_instances: dict[InstanceName, SDict],
    connections: Connections,
    ports: Ports,
) -> Any

Analyze circuit topology for the additive backend.

Prepares the circuit connection information for evaluation by the additive backend, which finds all paths between circuit ports in the connection graph and sums their contributions. The analysis step itself is a pass-through: it returns the connections and ports unchanged.

Parameters:

Name Type Description Default
analyzed_instances dict[InstanceName, SDict]

Instance S-matrices from analyze_instances_additive. Not used in this analysis step but required for interface consistency.

required
connections Connections

Dictionary mapping instance ports to each other, defining internal circuit connections.

required
ports Ports

Dictionary mapping external port names to instance ports.

required

Returns:

Type Description
Any

The tuple (connections, ports), returned unchanged for circuit evaluation.

Example
connections = {"wg1,out": "dc1,in1", "dc1,out1": "wg2,in"}
ports = {"in": "wg1,in", "out": "wg2,out"}
analyzed = analyze_circuit_additive(analyzed_instances, connections, ports)
Source code in src/sax/backends/additive.py
def analyze_circuit_additive(
    analyzed_instances: dict[sax.InstanceName, sax.SDict],  # noqa: ARG001
    connections: sax.Connections,
    ports: sax.Ports,
) -> Any:  # noqa: ANN401
    """Analyze circuit topology for the additive backend.

    Prepares the circuit connection information for the additive backend
    evaluation. This backend uses graph theory to find all possible paths
    between circuit ports and sums their contributions.

    Args:
        analyzed_instances: Instance S-matrices from analyze_instances_additive.
            Not used in this analysis step but required for interface consistency.
        connections: Dictionary mapping instance ports to each other, defining
            internal circuit connections.
        ports: Dictionary mapping external port names to instance ports.

    Returns:
        Tuple containing connections and ports information for circuit evaluation.

    Example:
        ```python
        connections = {"wg1,out": "dc1,in1", "dc1,out1": "wg2,in"}
        ports = {"in": "wg1,in", "out": "wg2,out"}
        analyzed = analyze_circuit_additive(analyzed_instances, connections, ports)
        ```
    """
    return connections, ports
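The path-summing idea behind the additive backend can be sketched in plain Python: each SDict entry (p1, p2) of an instance becomes a directed edge weighted by its S-parameter, each connection joins two instance ports, and the S-parameter between two external ports is the sum, over all simple paths, of the product of the weights along the path. This is an illustrative sketch with hypothetical names (`additive_s`, `reciprocal`, and a toy MZI built from hard-coded SDicts), not SAX's implementation; because it only follows simple paths, it cannot account for repeated round trips through a feedback loop.

```python
import math


def additive_s(instances, connections, ports):
    """Sum, over every simple path between two external ports, the product
    of the S-parameters met along the path."""
    # Directed edges inside instances: "inst,p1" -> [("inst,p2", s), ...]
    internal = {}
    for name, sdict in instances.items():
        for (p1, p2), s in sdict.items():
            internal.setdefault(f"{name},{p1}", []).append((f"{name},{p2}", s))
    # Connections are undirected: look them up from either end.
    link = {**connections, **{b: a for a, b in connections.items()}}

    def walk(node, target, visited):
        total = 0j
        for nxt, s in internal.get(node, []):
            if nxt == target:
                total += s  # reached the requested external port
            elif nxt in link and nxt not in visited and link[nxt] not in visited:
                # Hop across the connection into the neighbouring instance.
                total += s * walk(link[nxt], target, visited | {nxt, link[nxt]})
        return total

    return {
        (p_in, p_out): walk(n_in, n_out, {n_in})
        for p_in, n_in in ports.items()
        for p_out, n_out in ports.items()
    }


def reciprocal(sdict):
    # Add the reverse direction of every entry (reciprocal components).
    return {**sdict, **{(p2, p1): s for (p1, p2), s in sdict.items()}}


# Toy MZI with hard-coded SDicts (hypothetical values, not SAX models).
t = kappa = math.sqrt(0.5)  # 50/50 coupler: through and cross amplitudes
coupler = reciprocal({
    ("in0", "out0"): t, ("in0", "out1"): 1j * kappa,
    ("in1", "out0"): 1j * kappa, ("in1", "out1"): t,
})
waveguide = reciprocal({("in0", "out0"): 1.0})  # lossless, zero phase

S = additive_s(
    {"lft": coupler, "top": waveguide, "btm": waveguide, "rgt": coupler},
    {"lft,out0": "top,in0", "lft,out1": "btm,in0",
     "top,out0": "rgt,in0", "btm,out0": "rgt,in1"},
    {"in0": "lft,in0", "in1": "lft,in1", "out0": "rgt,out0", "out1": "rgt,out1"},
)
# Bar port: the two arm paths cancel; cross port: they add up.
print(abs(S["in0", "out0"]), abs(S["in0", "out1"]))
```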

analyze_circuit_fg

analyze_circuit_fg(
    analyzed_instances: dict[str, SDict], connections: Connections, ports: Ports
) -> Any

Analyze circuit topology for the Filipsson-Gunnar backend.

Prepares the circuit connection information for the Filipsson-Gunnar backend evaluation. This implementation currently skips detailed analysis and passes the connection information directly to the evaluation phase.

Parameters:

Name Type Description Default
analyzed_instances dict[str, SDict]

Instance S-matrices from analyze_instances_fg. Not used in this analysis step but required for interface consistency.

required
connections Connections

Dictionary mapping instance ports to each other, defining internal circuit connections.

required
ports Ports

Dictionary mapping external port names to instance ports.

required

Returns:

Type Description
Any

The tuple (connections, ports), returned unchanged for circuit evaluation.

Example
connections = {"wg1,out": "dc1,in1", "dc1,out1": "wg2,in"}
ports = {"in": "wg1,in", "out": "wg2,out"}
analyzed = analyze_circuit_fg(analyzed_instances, connections, ports)
Source code in src/sax/backends/filipsson_gunnar.py
def analyze_circuit_fg(
    analyzed_instances: dict[str, sax.SDict],  # noqa: ARG001
    connections: sax.Connections,
    ports: sax.Ports,
) -> Any:  # noqa: ANN401
    """Analyze circuit topology for the Filipsson-Gunnar backend.

    Prepares the circuit connection information for the Filipsson-Gunnar backend
    evaluation. This implementation currently skips detailed analysis and passes
    the connection information directly to the evaluation phase.

    Args:
        analyzed_instances: Instance S-matrices from analyze_instances_fg.
            Not used in this analysis step but required for interface consistency.
        connections: Dictionary mapping instance ports to each other, defining
            internal circuit connections.
        ports: Dictionary mapping external port names to instance ports.

    Returns:
        Tuple containing connections and ports information for circuit evaluation.

    Example:
        ```python
        connections = {"wg1,out": "dc1,in1", "dc1,out1": "wg2,in"}
        ports = {"in": "wg1,in", "out": "wg2,out"}
        analyzed = analyze_circuit_fg(analyzed_instances, connections, ports)
        ```
    """
    return connections, ports  # skip analysis for now

analyze_circuit_forward

analyze_circuit_forward(
    analyzed_instances: dict[InstanceName, SDict],
    connections: Connections,
    ports: Ports,
) -> Any

Analyze circuit topology for the forward-only backend.

Prepares the circuit connection information for the forward-only backend evaluation. This backend assumes unidirectional signal flow and does not account for reflections or bidirectional coupling.

Parameters:

Name Type Description Default
analyzed_instances dict[InstanceName, SDict]

Instance S-matrices from analyze_instances_forward. Not used in this analysis step but required for interface consistency.

required
connections Connections

Dictionary mapping instance ports to each other, defining internal circuit connections.

required
ports Ports

Dictionary mapping external port names to instance ports.

required

Returns:

Type Description
Any

The tuple (connections, ports), returned unchanged for circuit evaluation.

Example
connections = {"wg1,out": "amp1,in", "amp1,out": "wg2,in"}
ports = {"in": "wg1,in", "out": "wg2,out"}
analyzed = analyze_circuit_forward(analyzed_instances, connections, ports)
Source code in src/sax/backends/forward_only.py
def analyze_circuit_forward(
    analyzed_instances: dict[sax.InstanceName, sax.SDict],  # noqa: ARG001
    connections: sax.Connections,
    ports: sax.Ports,
) -> Any:  # noqa: ANN401
    """Analyze circuit topology for the forward-only backend.

    Prepares the circuit connection information for the forward-only backend
    evaluation. This backend assumes unidirectional signal flow and does not
    account for reflections or bidirectional coupling.

    Args:
        analyzed_instances: Instance S-matrices from analyze_instances_forward.
            Not used in this analysis step but required for interface consistency.
        connections: Dictionary mapping instance ports to each other, defining
            internal circuit connections.
        ports: Dictionary mapping external port names to instance ports.

    Returns:
        Tuple containing connections and ports information for circuit evaluation.

    Example:
        ```python
        connections = {"wg1,out": "amp1,in", "amp1,out": "wg2,in"}
        ports = {"in": "wg1,in", "out": "wg2,out"}
        analyzed = analyze_circuit_forward(analyzed_instances, connections, ports)
        ```
    """
    return connections, ports

analyze_instances_additive

analyze_instances_additive(
    instances: Instances, models: Models
) -> dict[InstanceName, SDict]

Analyze circuit instances for the additive backend.

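Prepares instance S-matrices for the additive backend: each distinct component model is called once with its default arguments and the result is converted to SDict format, so instance settings are not applied at this stage. The additive backend uses a graph-based approach with path finding to compute circuit responses.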

Parameters:

Name Type Description Default
instances Instances

Dictionary mapping instance names to instance definitions containing component names and settings.

required
models Models

Dictionary mapping component names to their model functions.

required

Returns:

Type Description
dict[InstanceName, SDict]

Dictionary mapping instance names to their S-matrices in SDict format.

Example
instances = {
    "wg1": {"component": "waveguide", "settings": {"length": 10.0}},
    "dc1": {"component": "coupler", "settings": {"coupling": 0.1}},
}
models = {"waveguide": waveguide_model, "coupler": coupler_model}
analyzed = analyze_instances_additive(instances, models)
Source code in src/sax/backends/additive.py
def analyze_instances_additive(
    instances: sax.Instances,
    models: sax.Models,
) -> dict[sax.InstanceName, sax.SDict]:
    """Analyze circuit instances for the additive backend.

    Prepares instance S-matrices for the additive backend by converting all
    component models to SDict format. The additive backend uses a graph-based
    approach with path finding to compute circuit responses.

    Args:
        instances: Dictionary mapping instance names to instance definitions
            containing component names and settings.
        models: Dictionary mapping component names to their model functions.

    Returns:
        Dictionary mapping instance names to their S-matrices in SDict format.

    Example:
        ```python
        instances = {
            "wg1": {"component": "waveguide", "settings": {"length": 10.0}},
            "dc1": {"component": "coupler", "settings": {"coupling": 0.1}},
        }
        models = {"waveguide": waveguide_model, "coupler": coupler_model}
        analyzed = analyze_instances_additive(instances, models)
        ```
    """
    instances = sax.into[sax.Instances](instances)
    models = sax.into[sax.Models](models)
    model_names = set()
    for i in instances.values():
        model_names.add(i["component"])
    dummy_models = {k: sax.sdict(models[k]()) for k in model_names}
    dummy_instances = {}
    for k, i in instances.items():
        dummy_instances[k] = dummy_models[i["component"]]
    return dummy_instances

analyze_instances_fg

analyze_instances_fg(instances: Instances, models: Models) -> dict[InstanceName, SDict]

Analyze circuit instances for the Filipsson-Gunnar backend.

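Prepares instance S-matrices for the Filipsson-Gunnar backend: each distinct component model is called once with its default arguments and the result is converted to SDict format, so instance settings are not applied at this stage. This backend implements the classic Filipsson-Gunnar algorithm for the S-matrix interconnection of multiports.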

Parameters:

Name Type Description Default
instances Instances

Dictionary mapping instance names to instance definitions containing component names and settings.

required
models Models

Dictionary mapping component names to their model functions.

required

Returns:

Type Description
dict[InstanceName, SDict]

Dictionary mapping instance names to their S-matrices in SDict format.

Note

The Filipsson-Gunnar algorithm is a systematic method for computing the overall S-matrix of interconnected multiport networks described in: Filipsson, Gunnar. "A new general computer algorithm for S-matrix calculation of interconnected multiports." 11th European Microwave Conference. IEEE, 1981.

Example
instances = {
    "wg1": {"component": "waveguide", "settings": {"length": 10.0}},
    "dc1": {"component": "coupler", "settings": {"coupling": 0.1}},
}
models = {"waveguide": waveguide_model, "coupler": coupler_model}
analyzed = analyze_instances_fg(instances, models)
Source code in src/sax/backends/filipsson_gunnar.py
def analyze_instances_fg(
    instances: sax.Instances,
    models: sax.Models,
) -> dict[sax.InstanceName, sax.SDict]:
    """Analyze circuit instances for the Filipsson-Gunnar backend.

    Prepares instance S-matrices for the Filipsson-Gunnar backend by converting
    all component models to SDict format. This backend implements the classic
    Filipsson-Gunnar algorithm for S-matrix interconnection of multiports.

    Args:
        instances: Dictionary mapping instance names to instance definitions
            containing component names and settings.
        models: Dictionary mapping component names to their model functions.

    Returns:
        Dictionary mapping instance names to their S-matrices in SDict format.

    Note:
        The Filipsson-Gunnar algorithm is a systematic method for computing
        the overall S-matrix of interconnected multiport networks described in:
        Filipsson, Gunnar. "A new general computer algorithm for S-matrix
        calculation of interconnected multiports." 11th European Microwave
        Conference. IEEE, 1981.

    Example:
        ```python
        instances = {
            "wg1": {"component": "waveguide", "settings": {"length": 10.0}},
            "dc1": {"component": "coupler", "settings": {"coupling": 0.1}},
        }
        models = {"waveguide": waveguide_model, "coupler": coupler_model}
        analyzed = analyze_instances_fg(instances, models)
        ```
    """
    instances = sax.into[sax.Instances](instances)
    models = sax.into[sax.Models](models)
    model_names = set()
    for i in instances.values():
        model_names.add(i["component"])
    dummy_models = {k: sax.sdict(models[k]()) for k in model_names}
    dummy_instances = {}
    for k, i in instances.items():
        dummy_instances[k] = dummy_models[i["component"]]
    return dummy_instances
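The core step of the Filipsson-Gunnar method connects two ports k and l of a single (block-diagonal) S-matrix and eliminates them; repeating it once per connection leaves only the external ports. A minimal dictionary-based sketch of that step, where S[i, j] is read as the wave leaving port j per unit wave entering port i. The names `interconnect`, `fg_circuit` and `reciprocal` are hypothetical, and SAX's own backend may order and vectorize the steps differently. Unlike a simple-path sum, the closed-form update handles feedback loops exactly, which the all-pass ring below illustrates:

```python
import math


def interconnect(S, k, l):
    """Connect ports k and l of SDict S (the wave leaving k enters l and vice
    versa) and eliminate them with the two-port interconnection formula."""

    def g(i, j):
        return S.get((i, j), 0.0)  # entries missing from the SDict are zero

    remaining = {p for pair in S for p in pair} - {k, l}
    den = (1 - g(k, l)) * (1 - g(l, k)) - g(k, k) * g(l, l)
    result = {}
    for i in remaining:
        for j in remaining:
            v = g(i, j) + (
                g(k, j) * g(i, l) * (1 - g(l, k))
                + g(l, j) * g(i, k) * (1 - g(k, l))
                + g(k, j) * g(l, l) * g(i, k)
                + g(l, j) * g(k, k) * g(i, l)
            ) / den
            if v != 0:
                result[i, j] = v
    return result


def fg_circuit(instances, connections, ports):
    # Block-diagonal starting point: every instance port becomes "instance,port".
    S = {
        (f"{name},{p1}", f"{name},{p2}"): s
        for name, sdict in instances.items()
        for (p1, p2), s in sdict.items()
    }
    for k, l in connections.items():
        S = interconnect(S, k, l)  # eliminate one connection at a time
    # Keep only external ports and rename them.
    ext = {inner: outer for outer, inner in ports.items()}
    return {(ext[i], ext[j]): s for (i, j), s in S.items() if i in ext and j in ext}


def reciprocal(sdict):
    # Add the reverse direction of every entry (reciprocal components).
    return {**sdict, **{(p2, p1): s for (p1, p2), s in sdict.items()}}


# Hypothetical all-pass ring: the coupler's out1 -> in1 path closes a loop.
t = kappa = math.sqrt(0.5)  # through/cross amplitudes (t**2 + kappa**2 = 1)
a = 0.9                     # hypothetical round-trip amplitude of the ring
coupler = reciprocal({
    ("in0", "out0"): t, ("in0", "out1"): 1j * kappa,
    ("in1", "out0"): 1j * kappa, ("in1", "out1"): t,
})
ring_wg = reciprocal({("in0", "out0"): a})

S = fg_circuit(
    {"cpl": coupler, "ring": ring_wg},
    {"cpl,out1": "ring,in0", "ring,out0": "cpl,in1"},
    {"in0": "cpl,in0", "out0": "cpl,out0"},
)
expected = (t - a) / (1 - t * a)  # all-pass ring response for this coupler convention
print(S["in0", "out0"], expected)
```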

analyze_instances_forward

analyze_instances_forward(
    instances: Instances, models: Models
) -> dict[InstanceName, SCoo]

Analyze circuit instances for the forward-only backend.

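Prepares instance S-matrices for the forward-only backend: each distinct component model is called once with its default arguments and the result is converted to SCoo (coordinate) format, so instance settings are not applied at this stage. This backend is specialized for feed-forward circuits without feedback loops.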

Parameters:

Name Type Description Default
instances Instances

Dictionary mapping instance names to instance definitions containing component names and settings.

required
models Models

Dictionary mapping component names to their model functions.

required

Returns:

Type Description
dict[InstanceName, SCoo]

Dictionary mapping instance names to their S-matrices in SCoo format.

Note

The forward-only backend is designed for circuits with unidirectional signal flow and no feedback paths. It uses a simplified approach that may not be accurate for circuits with reflections or loops.

Example
instances = {
    "wg1": {"component": "waveguide", "settings": {"length": 10.0}},
    "amp1": {"component": "amplifier", "settings": {"gain": 20.0}},
}
models = {"waveguide": waveguide_model, "amplifier": amplifier_model}
analyzed = analyze_instances_forward(instances, models)
Source code in src/sax/backends/forward_only.py
def analyze_instances_forward(
    instances: sax.Instances,
    models: sax.Models,
) -> dict[sax.InstanceName, sax.SCoo]:
    """Analyze circuit instances for the forward-only backend.

    Prepares instance S-matrices for the forward-only backend by converting all
    component models to SCoo (coordinate) format. This backend is specialized
    for feed-forward circuits without feedback loops.

    Args:
        instances: Dictionary mapping instance names to instance definitions
            containing component names and settings.
        models: Dictionary mapping component names to their model functions.

    Returns:
        Dictionary mapping instance names to their S-matrices in SCoo format.

    Note:
        The forward-only backend is designed for circuits with unidirectional
        signal flow and no feedback paths. It uses a simplified approach that
        may not be accurate for circuits with reflections or loops.

    Example:
        ```python
        instances = {
            "wg1": {"component": "waveguide", "settings": {"length": 10.0}},
            "amp1": {"component": "amplifier", "settings": {"gain": 20.0}},
        }
        models = {"waveguide": waveguide_model, "amplifier": amplifier_model}
        analyzed = analyze_instances_forward(instances, models)
        ```
    """
    instances = sax.into[sax.Instances](instances)
    models = sax.into[sax.Models](models)
    model_names = set()
    for i in instances.values():
        model_names.add(i["component"])
    dummy_models = {k: sax.scoo(models[k]()) for k in model_names}
    dummy_instances = {}
    for k, i in instances.items():
        dummy_instances[k] = dummy_models[i["component"]]
    return dummy_instances
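SCoo stores an S-matrix in coordinate (COO) form, assumed here to be the tuple (Si, Sj, Sx, port_map): row indices, column indices, values, and a port-name-to-index map. A minimal NumPy sketch of turning an SDict into such a tuple and back into a dense matrix; `sdict_to_coo` and `coo_to_dense` are hypothetical helpers, the alphabetical port ordering is an assumption, and in SAX itself the conversion is done by `sax.scoo` as in the source above:

```python
import numpy as np


def sdict_to_coo(sdict):
    """Hypothetical: SDict -> (Si, Sj, Sx, port_map) coordinate form."""
    ports = sorted({p for pair in sdict for p in pair})  # assumed ordering
    port_map = {p: i for i, p in enumerate(ports)}
    Si = np.array([port_map[p1] for p1, _ in sdict], dtype=int)
    Sj = np.array([port_map[p2] for _, p2 in sdict], dtype=int)
    Sx = np.array(list(sdict.values()), dtype=complex)
    return Si, Sj, Sx, port_map


def coo_to_dense(Si, Sj, Sx, port_map):
    # Scatter the stored entries into a full matrix; missing entries stay zero.
    n = len(port_map)
    S = np.zeros((n, n), dtype=complex)
    S[Si, Sj] = Sx
    return S


# Hypothetical one-way element: only the forward entry is stored.
isolator = {("in0", "out0"): 0.95}
Si, Sj, Sx, port_map = sdict_to_coo(isolator)
print(port_map)
print(coo_to_dense(Si, Sj, Sx, port_map))
```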

block_diag

block_diag(*arrs: Array) -> Array

Create block diagonal matrix with arbitrary batch dimensions.

Constructs a block diagonal matrix from square input matrices. All input arrays must have the same batch dimensions and be square in their last two dimensions.

Parameters:

Name Type Description Default
*arrs Array

Square matrices to place on the block diagonal. All must have matching batch dimensions.

()

Returns:

Type Description
Array

Block diagonal matrix with batch dimensions preserved.

Raises:

Type Description
ValueError

If batch dimensions don't match or matrices aren't square.

Example
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
C = block_diag(A, B)
# Result: [[1, 2, 0, 0],
#          [3, 4, 0, 0],
#          [0, 0, 5, 6],
#          [0, 0, 7, 8]]
Source code in src/sax/s.py
def block_diag(*arrs: Array) -> Array:
    """Create block diagonal matrix with arbitrary batch dimensions.

    Constructs a block diagonal matrix from square input matrices. All input
    arrays must have the same batch dimensions and be square in their last
    two dimensions.

    Args:
        *arrs: Square matrices to place on the block diagonal. All must have
            matching batch dimensions.

    Returns:
        Block diagonal matrix with batch dimensions preserved.

    Raises:
        ValueError: If batch dimensions don't match or matrices aren't square.

    Example:
        ```python
        import jax.numpy as jnp

        A = jnp.array([[1, 2], [3, 4]])
        B = jnp.array([[5, 6], [7, 8]])
        C = block_diag(A, B)
        # Result: [[1, 2, 0, 0],
        #          [3, 4, 0, 0],
        #          [0, 0, 5, 6],
        #          [0, 0, 7, 8]]
        ```
    """
    batch_shape = arrs[0].shape[:-2]

    N = 0
    for arr in arrs:
        if batch_shape != arr.shape[:-2]:
            msg = "Batch dimensions for given arrays don't match."
            raise ValueError(msg)
        m, n = arr.shape[-2:]
        if m != n:
            msg = "given arrays are not square."
            raise ValueError(msg)
        N += n

    arrs = tuple(arr.reshape(-1, arr.shape[-2], arr.shape[-1]) for arr in arrs)
    batch_block_diag = jax.vmap(jsp.linalg.block_diag, in_axes=0, out_axes=0)
    block_diag = batch_block_diag(*arrs)
    return block_diag.reshape(*batch_shape, N, N)
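The source above flattens the batch dimensions, maps `jax.scipy.linalg.block_diag` over them with `jax.vmap`, and reshapes back. The same result can be sketched in plain NumPy by allocating a zero array and writing each square block into its slot on the diagonal; `batched_block_diag` is a hypothetical name and NumPy stands in for `jax.numpy`:

```python
import numpy as np


def batched_block_diag(*arrs: np.ndarray) -> np.ndarray:
    """NumPy sketch: block-diagonal matrix over shared leading batch dims."""
    batch_shape = arrs[0].shape[:-2]
    for arr in arrs:
        if arr.shape[:-2] != batch_shape:
            raise ValueError("Batch dimensions for given arrays don't match.")
        if arr.shape[-2] != arr.shape[-1]:
            raise ValueError("given arrays are not square.")
    N = sum(arr.shape[-1] for arr in arrs)
    out = np.zeros((*batch_shape, N, N), dtype=np.result_type(*arrs))
    offset = 0
    for arr in arrs:
        n = arr.shape[-1]
        # Place this block on the diagonal for every batch element at once.
        out[..., offset:offset + n, offset:offset + n] = arr
        offset += n
    return out


A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
print(batched_block_diag(A, B))

# Batched: three 2x2 blocks next to three 1x1 blocks -> shape (3, 3, 3)
batch = batched_block_diag(np.ones((3, 2, 2)), np.full((3, 1, 1), 2.0))
print(batch.shape)
```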

cartesian_product

cartesian_product(*arrays: ComplexArray) -> ComplexArray

Calculate the n-dimensional Cartesian product of input arrays.

Creates all possible combinations of elements from the input arrays, useful for parameter sweeps and grid generation.

Parameters:

Name Type Description Default
*arrays ComplexArray

Variable number of arrays to compute Cartesian product for.

()

Returns:

Type Description
ComplexArray

Two-dimensional array with one row per combination and one column per input array; the last input varies fastest.

Example
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20])
product = cartesian_product(x, y)
# Result: [[1, 10], [1, 20], [2, 10], [2, 20]]
Source code in src/sax/utils.py
def cartesian_product(*arrays: sax.ComplexArray) -> sax.ComplexArray:
    """Calculate the n-dimensional Cartesian product of input arrays.

    Creates all possible combinations of elements from the input arrays,
    useful for parameter sweeps and grid generation.

    Args:
        *arrays: Variable number of arrays to compute Cartesian product for.

    Returns:
        Array where each row is a unique combination of input elements.

    Example:
        ```python
        import jax.numpy as jnp

        x = jnp.array([1, 2])
        y = jnp.array([10, 20])
        product = cartesian_product(x, y)
        # Result: [[1, 10], [1, 20], [2, 10], [2, 20]]
        ```
    """
    ixarrays = jnp.ix_(*arrays)
    barrays = jnp.broadcast_arrays(*ixarrays)
    sarrays = jnp.stack(barrays, -1)
    product = sarrays.reshape(-1, sarrays.shape[-1])
    return product
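
The three steps in the body (ix_, broadcast_arrays, stack) can be traced in NumPy, whose functions of the same names behave the same way for 1-D inputs. This sketch, using the hypothetical name cartesian_product_np, checks that the row order matches itertools.product, i.e. the last input varies fastest.

```python
import itertools

import numpy as np


def cartesian_product_np(*arrays: np.ndarray) -> np.ndarray:
    # ix_ gives input i the shape (1, ..., len_i, ..., 1): an open mesh
    ixarrays = np.ix_(*arrays)
    # broadcasting expands every open mesh to the full grid shape
    barrays = np.broadcast_arrays(*ixarrays)
    # stacking on a new last axis pairs one coordinate from each input
    stacked = np.stack(barrays, -1)
    # flatten the grid: one row per combination
    return stacked.reshape(-1, stacked.shape[-1])


wl = np.array([1.50, 1.55])
length = np.array([10.0, 20.0, 30.0])
grid = cartesian_product_np(wl, length)
print(grid.shape)  # (6, 2)
expected = [list(p) for p in itertools.product(wl.tolist(), length.tolist())]
print(grid.tolist() == expected)  # True
```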

circuit

circuit(
    netlist: AnyNetlist,
    models: Models | None = None,
    *,
    backend: BackendLike = "default",
    top_level_name: str = "top_level",
    ignore_impossible_connections: bool = False,
) -> tuple[SDictModel, CircuitInfo]
circuit(
    netlist: AnyNetlist,
    models: Models | None = None,
    *,
    backend: BackendLike = "default",
    return_type: Literal["SDict"],
    top_level_name: str = "top_level",
    ignore_impossible_connections: bool = False,
) -> tuple[SDictModel, CircuitInfo]
circuit(
    netlist: AnyNetlist,
    models: Models | None = None,
    *,
    backend: BackendLike = "default",
    return_type: Literal["SDense"],
    top_level_name: str = "top_level",
    ignore_impossible_connections: bool = False,
) -> tuple[SDenseModel, CircuitInfo]
circuit(
    netlist: AnyNetlist,
    models: Models | None = None,
    *,
    backend: BackendLike = "default",
    return_type: Literal["SCoo"],
    top_level_name: str = "top_level",
    ignore_impossible_connections: bool = False,
) -> tuple[SCooModel, CircuitInfo]
circuit(
    netlist: AnyNetlist,
    models: Models | None = None,
    *,
    backend: BackendLike = "default",
    return_type: Literal["SDict", "SDense", "SCoo"] = "SDict",
    top_level_name: str = "top_level",
    ignore_impossible_connections: bool = False,
) -> tuple[Model, CircuitInfo]

Create a circuit function for a given netlist.

Constructs a circuit model from a netlist description by connecting component models according to the specified connections. The resulting circuit function can be called with parameters to evaluate the overall S-matrix.

Parameters:

Name Type Description Default
netlist AnyNetlist

Circuit netlist specifying instances, connections, and ports. Can be a flat netlist or recursive netlist dictionary.

required
models Models | None

Dictionary mapping component names to their model functions. If None, models must be provided in the netlist itself.

None
backend BackendLike

Circuit analysis backend to use. Options include "default", "klu", "filipsson_gunnar", "additive", "forward". Defaults to "default".

'default'
return_type Literal['SDict', 'SDense', 'SCoo']

Format of the returned S-matrix. Options: "SDict", "SDense", "SCoo". Defaults to "SDict".

'SDict'
top_level_name str

Name of the top-level circuit in recursive netlists. Defaults to "top_level".

'top_level'
ignore_impossible_connections bool

If True, ignore connections to missing instance ports instead of raising an error. Defaults to False.

False

Returns:

Type Description
tuple[Model, CircuitInfo]

Tuple containing:

- Circuit model function that accepts parameters and returns the S-matrix
- CircuitInfo object with DAG, models, and backend information

Raises:

Type Description
RuntimeError

If circuit construction fails.

ValueError

If netlist or models are invalid.

KeyError

If required models are missing.

Example
import numpy as np

from sax import circuit

# Define component models
def waveguide(length=10.0, neff=2.4, wl=1.55):
    phase = 2 * np.pi * neff * length / wl
    return {("in", "out"): np.exp(1j * phase)}


# Create netlist
netlist = {
    "instances": {
        "wg1": {"component": "waveguide", "settings": {"length": 20.0}}
    },
    "connections": {},
    "ports": {"in": "wg1,in", "out": "wg1,out"},
}

# Create circuit
models = {"waveguide": waveguide}
circuit_func, info = circuit(netlist, models)

# Evaluate circuit
s_matrix = circuit_func(wl=1.55)
Source code in src/sax/circuits.py
def circuit(
    netlist: sax.AnyNetlist,
    models: sax.Models | None = None,
    *,
    backend: sax.BackendLike = "default",
    return_type: Literal["SDict", "SDense", "SCoo"] = "SDict",
    top_level_name: str = "top_level",
    ignore_impossible_connections: bool = False,
) -> tuple[sax.Model, sax.CircuitInfo]:
    """Create a circuit function for a given netlist.

    Constructs a circuit model from a netlist description by connecting component
    models according to the specified connections. The resulting circuit function
    can be called with parameters to evaluate the overall S-matrix.

    Args:
        netlist: Circuit netlist specifying instances, connections, and ports.
            Can be a flat netlist or recursive netlist dictionary.
        models: Dictionary mapping component names to their model functions.
            If None, models must be provided in the netlist itself.
        backend: Circuit analysis backend to use. Options include "default",
            "klu", "filipsson_gunnar", "additive", "forward". Defaults to "default".
        return_type: Format of the returned S-matrix. Options: "SDict", "SDense",
            "SCoo". Defaults to "SDict".
        top_level_name: Name of the top-level circuit in recursive netlists.
            Defaults to "top_level".
        ignore_impossible_connections: If True, ignore connections to missing
            instance ports instead of raising an error. Defaults to False.

    Returns:
        Tuple containing:
            - Circuit model function that accepts parameters and returns S-matrix
            - CircuitInfo object with DAG, models, and backend information

    Raises:
        RuntimeError: If circuit construction fails.
        ValueError: If netlist or models are invalid.
        KeyError: If required models are missing.

    Example:
        ```python
        import numpy as np

        # Define component models
        def waveguide(length=10.0, neff=2.4, wl=1.55):
            phase = 2 * np.pi * neff * length / wl
            return {("in", "out"): np.exp(1j * phase)}


        # Create netlist
        netlist = {
            "instances": {
                "wg1": {"component": "waveguide", "settings": {"length": 20.0}}
            },
            "connections": {},
            "ports": {"in": "wg1,in", "out": "wg1,out"},
        }

        # Create circuit
        models = {"waveguide": waveguide}
        circuit_func, info = circuit(netlist, models)

        # Evaluate circuit
        s_matrix = circuit_func(wl=1.55)
        ```
    """
    _backend = sax.into[sax.Backend](backend)

    instance_models = _extract_instance_models(netlist)
    recnet = into_recnet(
        netlist,
        top_level_name=top_level_name,
    )
    recnet = convert_nets_to_connections(recnet)
    recnet = resolve_array_instances(recnet)
    recnet = remove_unused_instances(recnet)
    _validate_netlist_ports(recnet)
    dependency_dag = _create_dag(recnet, models, validate=True)
    models = _validate_models(
        models or {}, dependency_dag, extra_models=instance_models
    )

    circuit = None
    new_models = {}
    current_models = {}
    model_names = list(nx.topological_sort(dependency_dag))[::-1]
    for model_name in model_names:
        if model_name in models:
            new_models[model_name] = models[model_name]
            continue

        flatnet = recnet[model_name]
        current_models |= new_models
        new_models = {}

        current_models[model_name] = circuit = _flat_circuit(
            flatnet["instances"],
            flatnet.get("connections", {}),
            flatnet["ports"],
            current_models,
            _backend,
            ignore_impossible_connections=ignore_impossible_connections,
        )

    if circuit is None:
        msg = "Could not construct circuit (unknown reason)"
        raise RuntimeError(msg)
    circuit = _enforce_return_type(circuit, return_type)
    return circuit, sax.CircuitInfo(
        dag=dependency_dag,
        models=current_models,
        backend=_backend,
    )
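
circuit walks the dependency DAG in reverse topological order, so every sub-circuit is compiled before any netlist that instantiates it. The ordering alone can be sketched with the standard-library graphlib. The simplified netlist below (instances mapped straight to component names) and the helper build_order are hypothetical; they illustrate the ordering only and do not reproduce SAX's _create_dag.

```python
from graphlib import TopologicalSorter

# Simplified recursive netlist: each instance maps directly to a component name
recnet = {
    "top_level": {"instances": {"mzi1": "mzi", "wg_out": "waveguide"}},
    "mzi": {"instances": {"arm1": "waveguide", "arm2": "waveguide", "dc": "coupler"}},
}


def build_order(recnet: dict) -> list[str]:
    # Each netlist depends on the components its instances use. TopologicalSorter
    # expects node -> predecessors, so dependencies are emitted first.
    deps = {name: set(net["instances"].values()) for name, net in recnet.items()}
    return list(TopologicalSorter(deps).static_order())


order = build_order(recnet)
for name in order:
    kind = "sub-circuit" if name in recnet else "leaf model"
    print(f"{name}: {kind}")
```

In circuit itself, names found in models are collected as leaf models, and each sub-circuit in recnet is then built with _flat_circuit against the models gathered so far; the last one built is the top-level circuit.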

clean_string

clean_string(s: str, dot: str = 'p', minus: str = 'm', other: str = '_') -> Name

Clean a string to create a valid Python identifier.

Converts arbitrary strings to valid Python identifiers by replacing special characters and ensuring the result starts with a letter or underscore.

Parameters:

Name Type Description Default
s str

String to clean.

required
dot str

Replacement for dots. Defaults to "p".

'p'
minus str

Replacement for minus/dash. Defaults to "m".

'm'
other str

Replacement for other special characters. Defaults to "_".

'_'

Returns:

Type Description
Name

Valid Python identifier string.

Raises:

Type Description
ValueError

If cleaning fails to produce a valid identifier.

Example
clean_string("my-component.v2")  # Result: "mymcomponentpv2"
clean_string("2stage_amp")  # Result: "_2stage_amp"
clean_string("TE-mode")  # Result: "TEmmode"
Source code in src/sax/utils.py
def clean_string(
    s: str, dot: str = "p", minus: str = "m", other: str = "_"
) -> sax.Name:
    """Clean a string to create a valid Python identifier.

    Converts arbitrary strings to valid Python identifiers by replacing special
    characters and ensuring the result starts with a letter or underscore.

    Args:
        s: String to clean.
        dot: Replacement for dots. Defaults to "p".
        minus: Replacement for minus/dash. Defaults to "m".
        other: Replacement for other special characters. Defaults to "_".

    Returns:
        Valid Python identifier string.

    Raises:
        ValueError: If cleaning fails to produce a valid identifier.

    Example:
        ```python
        clean_string("my-component.v2")  # Result: "mymcomponentpv2"
        clean_string("2stage_amp")  # Result: "_2stage_amp"
        clean_string("TE-mode")  # Result: "TEmmode"
        ```
    """
    s = s.strip()
    s = s.replace(".", dot)  # dot
    s = s.replace("-", minus)  # minus
    s = re.sub("[^0-9a-zA-Z_]", other, s)
    if s[0] in "0123456789":
        s = "_" + s
    if not s.isidentifier():
        msg = f"failed to clean string to a valid python identifier: {s}"
        raise ValueError(msg)
    return s
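
The replacement rules are applied in a fixed order: dots first, then dashes, then any remaining character outside [0-9a-zA-Z_], and finally a leading digit gets an underscore prefix. The sketch below copies that logic into a standalone function (clean_string_sketch, a hypothetical name) so the outputs can be checked without installing SAX. Note that dashes become "m" by default; pass minus="_" to get underscores instead.

```python
import re


def clean_string_sketch(s: str, dot: str = "p", minus: str = "m", other: str = "_") -> str:
    s = s.strip()
    s = s.replace(".", dot)  # "." -> "p" by default
    s = s.replace("-", minus)  # "-" -> "m" by default
    s = re.sub("[^0-9a-zA-Z_]", other, s)  # everything else -> "_"
    if s[0] in "0123456789":
        s = "_" + s  # identifiers may not start with a digit
    if not s.isidentifier():
        raise ValueError(f"failed to clean string to a valid python identifier: {s}")
    return s


print(clean_string_sketch("my-component.v2"))  # mymcomponentpv2
print(clean_string_sketch("my-component.v2", minus="_"))  # my_componentpv2
print(clean_string_sketch("wl = 1.55 um"))  # wl___1p55_um
```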

denormalize

denormalize(x: ComplexArray, normalization: Normalization) -> ComplexArray

Denormalize an array using provided normalization parameters.

Reverses z-score normalization: x * std + mean.

Parameters:

Name Type Description Default
x ComplexArray

Normalized array to denormalize.

required
normalization Normalization

Normalization parameters containing mean and std.

required

Returns:

Type Description
ComplexArray

Denormalized array with original scale and offset.

Example
import jax.numpy as jnp

from sax import Normalization, denormalize

normalized_data = jnp.array([-1, 0, 1])  # normalized
norm_params = Normalization(mean=3.0, std=2.0)
original = denormalize(normalized_data, norm_params)
# Result: [1, 3, 5]
Source code in src/sax/utils.py
def denormalize(x: sax.ComplexArray, normalization: Normalization) -> sax.ComplexArray:
    """Denormalize an array using provided normalization parameters.

    Reverses z-score normalization: x * std + mean.

    Args:
        x: Normalized array to denormalize.
        normalization: Normalization parameters containing mean and std.

    Returns:
        Denormalized array with original scale and offset.

    Example:
        ```python
        normalized_data = jnp.array([-1, 0, 1])  # normalized
        norm_params = Normalization(mean=3.0, std=2.0)
        original = denormalize(normalized_data, norm_params)
        # Result: [1, 3, 5]
        ```
    """
    mean, std = normalization
    return x * std + mean
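
denormalize inverts z-score normalization, x_norm = (x - mean) / std. A round trip with the standard library makes the relationship concrete. Norm is a hypothetical stand-in for sax.Normalization, assumed to unpack as (mean, std) as the function body above does, and normalize is a hypothetical forward helper.

```python
import statistics
from typing import NamedTuple


class Norm(NamedTuple):
    # stand-in for sax.Normalization; unpacks as (mean, std)
    mean: float
    std: float


def normalize(xs: list[float], n: Norm) -> list[float]:
    return [(x - n.mean) / n.std for x in xs]


def denormalize(xs: list[float], n: Norm) -> list[float]:
    mean, std = n
    return [x * std + mean for x in xs]


data = [1.0, 3.0, 5.0]
# population statistics of the data define the normalization
norm = Norm(mean=statistics.fmean(data), std=statistics.pstdev(data))
z = normalize(data, norm)
restored = denormalize(z, norm)
print(z)
print(restored)
```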

draw_dag

draw_dag(dag: DiGraph, *, with_labels: bool = True, **kwargs: Any) -> None

Draw a directed acyclic graph (DAG) representing circuit dependencies.

Visualizes the dependency graph of a circuit with networkx. If the Graphviz dot executable is found on the PATH, nodes are placed with pydot's hierarchical "dot" layout (this also requires pydot to be installed); otherwise a custom fallback layout is used.

Parameters:

Name Type Description Default
dag DiGraph

Directed acyclic graph representing circuit component dependencies.

required
with_labels bool

Whether to display node labels. Defaults to True.

True
**kwargs Any

Additional keyword arguments passed to networkx draw function.

{}
Example
import matplotlib.pyplot as plt

# Assuming you have a circuit with dependencies
_, info = circuit(netlist, models)
draw_dag(info.dag)
plt.show()
Source code in src/sax/circuits.py
def draw_dag(dag: nx.DiGraph, *, with_labels: bool = True, **kwargs: Any) -> None:  # noqa: ANN401
    """Draw a directed acyclic graph (DAG) representing circuit dependencies.

    Visualizes the dependency graph of a circuit using networkx. If pydot/graphviz
    is available, uses hierarchical layout; otherwise falls back to a custom layout.

    Args:
        dag: Directed acyclic graph representing circuit component dependencies.
        with_labels: Whether to display node labels. Defaults to True.
        **kwargs: Additional keyword arguments passed to networkx draw function.

    Example:
        ```python
        import matplotlib.pyplot as plt

        # Assuming you have a circuit with dependencies
        _, info = circuit(netlist, models)
        draw_dag(info.dag)
        plt.show()
        ```
    """
    if shutil.which("dot"):
        return nx.draw(
            dag,
            nx.nx_pydot.pydot_layout(dag, prog="dot"),
            with_labels=with_labels,
            **kwargs,
        )
    return nx.draw(dag, _my_dag_pos(dag), with_labels=with_labels, **kwargs)

evaluate_circuit_additive

evaluate_circuit_additive(analyzed: Any, instances: dict[InstanceName, SDict]) -> SDict

Evaluate circuit S-matrix using additive path-based method.

Computes the circuit S-matrix by finding all possible signal paths between external ports and additively combining their contributions. This approach works well for circuits with multiple parallel paths.

The algorithm:

1. Creates a graph representation of the circuit
2. Finds all simple paths between each pair of external ports
3. Calculates the transmission/reflection for each path
4. Sums contributions from all paths

Parameters:

Name Type Description Default
analyzed Any

Circuit analysis data from analyze_circuit_additive containing connections and ports information.

required
instances dict[InstanceName, SDict]

Dictionary mapping instance names to their evaluated S-matrices in SDict format.

required

Returns:

Type Description
SDict

Circuit S-matrix in SDict format with external port combinations as keys.

Example
# Evaluated instance S-matrices
instances = {
    "wg1": {("in", "out"): 0.95 * jnp.exp(1j * 0.1)},
    "dc1": {("in1", "out1"): 0.9, ("in1", "out2"): 0.1},
}
circuit_s = evaluate_circuit_additive(analyzed, instances)
# Result contains S-parameters between external ports
Source code in src/sax/backends/additive.py
def evaluate_circuit_additive(
    analyzed: Any,  # noqa: ANN401
    instances: dict[sax.InstanceName, sax.SDict],
) -> sax.SDict:
    """Evaluate circuit S-matrix using additive path-based method.

    Computes the circuit S-matrix by finding all possible signal paths between
    external ports and additively combining their contributions. This approach
    works well for circuits with multiple parallel paths.

    The algorithm:
    1. Creates a graph representation of the circuit
    2. Finds all simple paths between each pair of external ports
    3. Calculates the transmission/reflection for each path
    4. Sums contributions from all paths

    Args:
        analyzed: Circuit analysis data from analyze_circuit_additive containing
            connections and ports information.
        instances: Dictionary mapping instance names to their evaluated S-matrices
            in SDict format.

    Returns:
        Circuit S-matrix in SDict format with external port combinations as keys.

    Example:
        ```python
        # Evaluated instance S-matrices
        instances = {
            "wg1": {("in", "out"): 0.95 * jnp.exp(1j * 0.1)},
            "dc1": {("in1", "out1"): 0.9, ("in1", "out2"): 0.1},
        }
        circuit_s = evaluate_circuit_additive(analyzed, instances)
        # Result contains S-parameters between external ports
        ```
    """
    connections, ports = analyzed
    edges = _graph_edges(instances, connections, ports)

    graph = nx.Graph()
    graph.add_edges_from(edges)
    _prune_internal_output_nodes(graph)

    sdict = {}
    for source in ports:
        for target in ports:
            paths = _get_possible_paths(graph, source=("", source), target=("", target))
            if not paths:
                continue
            sdict[source, target] = _path_lengths(graph, paths)

    return sdict
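
The path-based idea can be sketched without networkx: enumerate the simple paths between two ports with a depth-first search and add up one value per path. In this sketch each path's value is the product of its edge values, as for a transmission; the directed adjacency-dict format, the edge values and the helper names are hypothetical, and how SAX's private _path_lengths combines edges is not shown in this listing. Because only simple paths are counted, a path that revisits a node (such as a round trip in a cavity) contributes nothing here.

```python
from collections.abc import Iterator


def simple_paths(adj: dict[str, dict[str, float]], src: str, dst: str) -> Iterator[list[str]]:
    """Yield every path from src to dst that visits each node at most once."""
    stack = [(src, [src])]
    while stack:
        node, path = stack.pop()
        if node == dst:
            yield path
            continue
        for nxt in adj.get(node, {}):
            if nxt not in path:
                stack.append((nxt, path + [nxt]))


def path_sum(adj: dict[str, dict[str, float]], src: str, dst: str) -> float:
    total = 0.0
    for path in simple_paths(adj, src, dst):
        value = 1.0
        for a, b in zip(path, path[1:]):
            value *= adj[a][b]  # multiply edge values along one path
        total += value  # add the contributions of all paths
    return total


# two parallel arms between "in" and "out" (hypothetical values)
adj = {
    "in": {"arm1": 0.5, "arm2": 0.5},
    "arm1": {"out": 0.9},
    "arm2": {"out": 0.8},
}
print(sorted(simple_paths(adj, "in", "out")))
print(path_sum(adj, "in", "out"))  # 0.5*0.9 + 0.5*0.8
```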

evaluate_circuit_fg

evaluate_circuit_fg(analyzed: Any, instances: dict[str, SType]) -> SDict

Evaluate circuit S-matrix using the Filipsson-Gunnar algorithm.

Computes the overall circuit S-matrix by systematically interconnecting multiport networks using the Filipsson-Gunnar algorithm. This method iteratively applies equation 6 from the original paper to connect ports.

The algorithm:

1. Creates a block diagonal S-matrix from all component S-matrices
2. Iteratively interconnects ports according to the connections
3. Applies the Filipsson-Gunnar interconnection formula for each connection
4. Extracts the final S-matrix for external ports

Parameters:

Name Type Description Default
analyzed Any

Circuit analysis data from analyze_circuit_fg containing connections and ports information.

required
instances dict[str, SType]

Dictionary mapping instance names to their evaluated S-matrices in any SAX format (will be converted to SDict).

required

Returns:

Type Description
SDict

Circuit S-matrix in SDict format with external port combinations as keys.

Note

The interconnection formula used is equation 6 from: Filipsson, Gunnar. "A new general computer algorithm for S-matrix calculation of interconnected multiports." 11th European Microwave Conference. IEEE, 1981.

Example
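import jax.numpy as jnp

# Example connections and external ports (hypothetical circuit)
connections = {"wg1,out": "dc1,in1"}
ports = {"in": "wg1,in", "out1": "dc1,out1", "out2": "dc1,out2"}
analyzed = (connections, ports)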
instances = {
    "wg1": {("in", "out"): 0.95 * jnp.exp(1j * 0.1)},
    "dc1": {("in1", "out1"): 0.9, ("in1", "out2"): 0.1},
}
circuit_s = evaluate_circuit_fg(analyzed, instances)
Source code in src/sax/backends/filipsson_gunnar.py
def evaluate_circuit_fg(
    analyzed: Any,  # noqa: ANN401
    instances: dict[str, sax.SType],
) -> sax.SDict:
    """Evaluate circuit S-matrix using the Filipsson-Gunnar algorithm.

    Computes the overall circuit S-matrix by systematically interconnecting
    multiport networks using the Filipsson-Gunnar algorithm. This method
    iteratively applies equation 6 from the original paper to connect ports.

    The algorithm:
    1. Creates a block diagonal S-matrix from all component S-matrices
    2. Iteratively interconnects ports according to the connections
    3. Applies the Filipsson-Gunnar interconnection formula for each connection
    4. Extracts the final S-matrix for external ports

    Args:
        analyzed: Circuit analysis data from analyze_circuit_fg containing
            connections and ports information.
        instances: Dictionary mapping instance names to their evaluated S-matrices
            in any SAX format (will be converted to SDict).

    Returns:
        Circuit S-matrix in SDict format with external port combinations as keys.

    Note:
        The interconnection formula used is equation 6 from:
        Filipsson, Gunnar. "A new general computer algorithm for S-matrix
        calculation of interconnected multiports." 11th European Microwave
        Conference. IEEE, 1981.

    Example:
        ```python
        # Circuit analysis and instances
        analyzed = (connections, ports)
        instances = {
            "wg1": {("in", "out"): 0.95 * jnp.exp(1j * 0.1)},
            "dc1": {("in1", "out1"): 0.9, ("in1", "out2"): 0.1},
        }
        circuit_s = evaluate_circuit_fg(analyzed, instances)
        ```
    """
    connections, ports = analyzed

    # it's actually easier working w reverse:
    reversed_ports = {v: k for k, v in ports.items()}

    block_diag = {}
    for name, S in instances.items():
        block_diag.update(
            {
                (f"{name},{p1}", f"{name},{p2}"): v
                for (p1, p2), v in sax.sdict(S).items()
            }
        )

    sorted_connections = sorted(connections.items(), key=_connections_sort_key)
    all_connected_instances = {k: {k} for k in instances}

    for k, l in sorted_connections:
        name1, _ = k.split(",")
        name2, _ = l.split(",")

        connected_instances = (
            all_connected_instances[name1] | all_connected_instances[name2]
        )
        for name in connected_instances:
            all_connected_instances[name] = connected_instances

        current_ports = tuple(
            p
            for instance in connected_instances
            for p in set([p for p, _ in block_diag] + [p for _, p in block_diag])
            if p.startswith(f"{instance},")
        )

        block_diag.update(_interconnect_ports(block_diag, current_ports, k, l))

        for i, j in list(block_diag.keys()):
            is_connected = j in (k, l) or i in (k, l)
            is_in_output_ports = i in reversed_ports and j in reversed_ports
            if is_connected and not is_in_output_ports:
                del block_diag[
                    i, j
                ]  # we're no longer interested in these port combinations

    circuit_sdict: sax.SDict = {
        (reversed_ports[i], reversed_ports[j]): v
        for (i, j), v in block_diag.items()
        if i in reversed_ports and j in reversed_ports
    }
    return circuit_sdict
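
Each pass of the loop above connects one pair of ports, k and l, and updates every remaining entry S_ij. The sketch below implements the pairwise interconnection formula as it is commonly stated for Filipsson's method (written out in the docstring); the helper name connect_pair and the flat {(i, j): value} layout are assumptions for illustration, and SAX's own _interconnect_ports is not reproduced here. The check connects two 2-ports in series with partial reflections at the joined ports, where the through term should reduce to the multiple-reflection result t1*t2 / (1 - r1*r2).

```python
def connect_pair(
    S: dict[tuple[str, str], complex], ports: list[str], k: str, l: str
) -> dict[tuple[str, str], complex]:
    """Connect port k to port l and return the S-matrix of the remaining ports.

    S'_ij = S_ij + [S_kj S_il (1 - S_lk) + S_lj S_ik (1 - S_kl)
                    + S_kj S_ll S_ik + S_lj S_kk S_il]
                   / [(1 - S_kl)(1 - S_lk) - S_kk S_ll]
    """

    def s(i: str, j: str) -> complex:
        return S.get((i, j), 0.0)  # missing entries are zero

    den = (1 - s(k, l)) * (1 - s(l, k)) - s(k, k) * s(l, l)
    remaining = [p for p in ports if p not in (k, l)]
    result = {}
    for i in remaining:
        for j in remaining:
            num = (
                s(k, j) * s(i, l) * (1 - s(l, k))
                + s(l, j) * s(i, k) * (1 - s(k, l))
                + s(k, j) * s(l, l) * s(i, k)
                + s(l, j) * s(k, k) * s(i, l)
            )
            value = s(i, j) + num / den
            if value != 0:
                result[i, j] = value
    return result


# device 1: ports a-b, device 2: ports c-d (block diagonal, hypothetical values)
t1, t2, r1, r2 = 0.9, 0.8, 0.5, 0.5
S = {
    ("a", "b"): t1, ("b", "a"): t1, ("b", "b"): r1,
    ("c", "d"): t2, ("d", "c"): t2, ("c", "c"): r2,
}
S_total = connect_pair(S, ["a", "b", "c", "d"], "b", "c")
print(S_total[("a", "d")], t1 * t2 / (1 - r1 * r2))
```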

evaluate_circuit_forward

evaluate_circuit_forward(analyzed: Any, instances: dict[str, SDict]) -> SDict

Evaluate circuit S-matrix using forward-only propagation.

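Computes the circuit response using a simplified forward propagation approach. This method assumes unidirectional signal flow and uses breadth-first search to propagate signals through the circuit without considering reflections. Only external ports whose names start with "in" are driven, and only external ports whose names start with "out" appear in the result.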

The algorithm:

1. Creates a directed graph representation of the circuit
2. For each input port, injects a unit signal
3. Uses BFS to propagate signals through the circuit
4. Records the signal levels at output ports

Parameters:

Name Type Description Default
analyzed Any

Circuit analysis data from analyze_circuit_forward containing connections and ports information.

required
instances dict[str, SDict]

Dictionary mapping instance names to their evaluated S-matrices in SDict format.

required

Returns:

Type Description
SDict

Circuit S-matrix in SDict format, typically containing only forward transmission terms (input to output ports).

Warning

This backend is only accurate for feed-forward circuits without reflections or feedback paths. For circuits with bidirectional coupling or reflections, use the Filipsson-Gunnar or KLU backends instead.

Example
# Circuit analysis and instances (feed-forward only)
analyzed = (connections, ports)
instances = {
    "wg1": {("in", "out"): 0.95},  # Low-loss waveguide
    "amp1": {("in", "out"): 10.0},  # 20dB amplifier
}
circuit_s = evaluate_circuit_forward(analyzed, instances)
# Result contains only forward transmission terms
Source code in src/sax/backends/forward_only.py
def evaluate_circuit_forward(
    analyzed: Any,  # noqa: ANN401
    instances: dict[str, sax.SDict],
) -> sax.SDict:
    """Evaluate circuit S-matrix using forward-only propagation.

    Computes the circuit response using a simplified forward propagation approach.
    This method assumes unidirectional signal flow and uses breadth-first search
    to propagate signals through the circuit without considering reflections.

    The algorithm:
    1. Creates a directed graph representation of the circuit
    2. For each input port, injects a unit signal
    3. Uses BFS to propagate signals through the circuit
    4. Records the signal levels at output ports

    Args:
        analyzed: Circuit analysis data from analyze_circuit_forward containing
            connections and ports information.
        instances: Dictionary mapping instance names to their evaluated S-matrices
            in SDict format.

    Returns:
        Circuit S-matrix in SDict format, typically containing only forward
        transmission terms (input to output ports).

    Warning:
        This backend is only accurate for feed-forward circuits without
        reflections or feedback paths. For circuits with bidirectional coupling
        or reflections, use the Filipsson-Gunnar or KLU backends instead.

    Example:
        ```python
        # Circuit analysis and instances (feed-forward only)
        analyzed = (connections, ports)
        instances = {
            "wg1": {("in", "out"): 0.95},  # Low-loss waveguide
            "amp1": {("in", "out"): 10.0},  # 20dB amplifier
        }
        circuit_s = evaluate_circuit_forward(analyzed, instances)
        # Result contains only forward transmission terms
        ```
    """
    connections, ports = analyzed
    edges = _graph_edges_directed(instances, connections, ports)

    graph = nx.DiGraph()
    graph.add_edges_from(edges)

    # Dictionary to store signals at each node
    circuit_sdict = {}
    for in_port in ports:
        if in_port.startswith("in"):
            node_signals = {("", in_port): 1}
            bfs_output = nx.bfs_layers(graph, ("", in_port))
            for layer in bfs_output:
                layer_signals = {}
                for node in layer:
                    if node in node_signals:
                        signal = node_signals[node]
                        for neighbor in graph.successors(node):
                            transmission = graph[node][neighbor]["transmission"]
                            if neighbor in layer_signals:
                                layer_signals[neighbor] += signal * transmission
                            else:
                                layer_signals[neighbor] = signal * transmission
                node_signals.update(layer_signals)
            sdict = {
                (in_port, p2): v
                for (p1, p2), v in node_signals.items()
                if p1 == "" and p2.startswith("out")
            }
            circuit_sdict.update(sdict)
    return circuit_sdict
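
In a feed-forward network the signal at each node is the sum, over its incoming edges, of the upstream signal times the edge transmission. The sketch below computes this with the standard-library graphlib, visiting nodes in topological order rather than in BFS layers (a swapped-in technique), so each node is evaluated only after all of its predecessors, including where branches of different lengths reconverge. The edge format {(u, v): transmission} and the helper name forward_signals are hypothetical.

```python
from graphlib import TopologicalSorter


def forward_signals(edges: dict[tuple[str, str], float], source: str) -> dict[str, float]:
    """Propagate a unit signal from source through a directed acyclic graph."""
    preds: dict[str, set[str]] = {}
    for u, v in edges:
        preds.setdefault(v, set()).add(u)
        preds.setdefault(u, set())
    signal: dict[str, float] = {}
    # static_order raises graphlib.CycleError if the graph has a feedback loop
    for node in TopologicalSorter(preds).static_order():
        if node == source:
            signal[node] = 1.0
        else:
            # every predecessor has already been visited at this point
            signal[node] = sum(signal[p] * edges[p, node] for p in preds[node])
    return signal


# a splitter whose two branches have different lengths before recombining
edges = {
    ("in", "a"): 0.5,
    ("in", "b"): 0.5,
    ("a", "out"): 0.9,
    ("b", "c"): 1.0,
    ("c", "out"): 0.8,
}
signals = forward_signals(edges, "in")
print(signals["out"])
```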

flatten_dict

flatten_dict(dic: dict[str, Any], sep: str = ',') -> dict[str, Any]

Flatten a nested dictionary into a single-level dictionary.

Converts a nested dictionary into a flat dictionary by concatenating nested keys with a separator.

Parameters:

Name Type Description Default
dic dict[str, Any]

Nested dictionary to flatten.

required
sep str

Separator to use between nested keys. Defaults to ",".

','

Returns:

Type Description
dict[str, Any]

Flattened dictionary with concatenated keys.

Example
nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
flat = flatten_dict(nested)
# Result: {"a,b": 1, "a,c,d": 2, "e": 3}
Source code in src/sax/utils.py
def flatten_dict(dic: dict[str, Any], sep: str = ",") -> dict[str, Any]:
    """Flatten a nested dictionary into a single-level dictionary.

    Converts a nested dictionary into a flat dictionary by concatenating
    nested keys with a separator.

    Args:
        dic: Nested dictionary to flatten.
        sep: Separator to use between nested keys. Defaults to ",".

    Returns:
        Flattened dictionary with concatenated keys.

    Example:
        ```python
        nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
        flat = flatten_dict(nested)
        # Result: {"a,b": 1, "a,c,d": 2, "e": 3}
        ```
    """
    return _flatten_dict(dic, sep=sep)
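
The flattening can be sketched as a short recursion, together with a hypothetical inverse to show that the concatenated keys carry the full nesting. Both helpers (flatten and unflatten) are illustrations, not SAX's private _flatten_dict; the round trip only holds when no key contains the separator, and empty sub-dictionaries disappear when flattened.

```python
from typing import Any


def flatten(dic: dict[str, Any], sep: str = ",", prefix: str = "") -> dict[str, Any]:
    out: dict[str, Any] = {}
    for key, value in dic.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten(value, sep, full_key))  # recurse into sub-dicts
        else:
            out[full_key] = value
    return out


def unflatten(flat: dict[str, Any], sep: str = ",") -> dict[str, Any]:
    nested: dict[str, Any] = {}
    for full_key, value in flat.items():
        *parents, leaf = full_key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels
        node[leaf] = value
    return nested


nested = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
flat = flatten(nested)
print(flat)  # {'a,b': 1, 'a,c,d': 2, 'e': 3}
```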

flatten_netlist

flatten_netlist(recnet: RecursiveNetlist, sep: str = '~') -> Netlist

Flatten a recursive netlist into a single flat netlist.

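Converts a hierarchical (recursive) netlist into a single flat netlist by inlining all sub-circuits. The first netlist in recnet is treated as the top level. Instance names are prefixed with their parent instance name, joined by sep, to avoid conflicts (for example "sub1~wg1").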

Parameters:

Name Type Description Default
recnet RecursiveNetlist

Recursive netlist to flatten.

required
sep str

Separator used for hierarchical instance naming. Defaults to "~".

'~'

Returns:

Type Description
Netlist

Single flat netlist with all hierarchies inlined.

Example
# Flatten a hierarchical netlist
recnet = {
    "top": {
        "instances": {"sub1": {"component": "subcircuit"}},
        "ports": {"in": "sub1,in"},
    },
    "subcircuit": {
        "instances": {"wg1": {"component": "waveguide"}},
        "ports": {"in": "wg1,in"},
    },
}
flat = flatten_netlist(recnet)
# Result has instances like "sub1~wg1" for the flattened hierarchy
Source code in src/sax/netlists.py
def flatten_netlist(recnet: sax.RecursiveNetlist, sep: str = "~") -> sax.Netlist:
    """Flatten a recursive netlist into a single flat netlist.

    Converts a hierarchical (recursive) netlist into a single flat netlist by
    inlining all sub-circuits. Instance names are prefixed to avoid conflicts.

    Args:
        recnet: Recursive netlist to flatten.
        sep: Separator used for hierarchical instance naming. Defaults to "~".

    Returns:
        Single flat netlist with all hierarchies inlined.

    Example:
        ```python
        # Flatten a hierarchical netlist
        recnet = {
            "top": {
                "instances": {"sub1": {"component": "subcircuit"}},
                "ports": {"in": "sub1,in"},
            },
            "subcircuit": {
                "instances": {"wg1": {"component": "waveguide"}},
                "ports": {"in": "wg1,in"},
            },
        }
        flat = flatten_netlist(recnet)
        # Result has instances like "sub1~wg1" for the flattened hierarchy
        ```
    """
    first_name = next(iter(recnet.keys()))
    net = deepcopy(recnet[first_name])
    _flatten_netlist_into(recnet, net, sep)
    return net
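
One inlining step can be sketched as follows: the sub-circuit instance is replaced by the child's instances under prefixed names, and every reference to one of its ports is redirected through the child's ports mapping. inline_instance is a hypothetical helper that handles a single level and ignores fields such as placements or nets; it is not SAX's _flatten_netlist_into.

```python
from copy import deepcopy


def inline_instance(parent: dict, inst: str, child: dict, sep: str = "~") -> dict:
    """Replace instance inst of parent by the contents of child (one level)."""

    def remap(ref: str) -> str:
        name, port = ref.split(",")
        if name != inst:
            return ref
        sub, sub_port = child["ports"][port].split(",")  # follow the child's port
        return f"{inst}{sep}{sub},{sub_port}"

    net = deepcopy(parent)
    del net["instances"][inst]
    for sub, spec in child["instances"].items():
        net["instances"][f"{inst}{sep}{sub}"] = deepcopy(spec)
    # parent connections are redirected, child connections are prefixed
    connections = {remap(a): remap(b) for a, b in parent.get("connections", {}).items()}
    for a, b in child.get("connections", {}).items():
        connections[f"{inst}{sep}{a}"] = f"{inst}{sep}{b}"
    net["connections"] = connections
    net["ports"] = {name: remap(ref) for name, ref in parent["ports"].items()}
    return net


recnet = {
    "top": {
        "instances": {"sub1": {"component": "subcircuit"}},
        "ports": {"in": "sub1,in"},
    },
    "subcircuit": {
        "instances": {"wg1": {"component": "waveguide"}},
        "ports": {"in": "wg1,in"},
    },
}
flat = inline_instance(recnet["top"], "sub1", recnet["subcircuit"])
print(flat["instances"])  # {'sub1~wg1': {'component': 'waveguide'}}
print(flat["ports"])  # {'in': 'sub1~wg1,in'}
```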

get_mode

get_mode(pm: PortMode) -> Mode

Extract the mode from a port@mode string.

Parses the mode identifier from the standard port@mode naming convention used in multimode S-matrices.

Parameters:

Name Type Description Default
pm PortMode

Port-mode string in the format "port@mode".

required

Returns:

Type Description
Mode

The mode identifier (part after @).

Example
mode = get_mode("waveguide1@TE")
# Result: "TE"

mode = get_mode("input@TM")
# Result: "TM"
Source code in src/sax/s.py
@validate_call
def get_mode(pm: sax.PortMode) -> sax.Mode:
    """Extract the mode from a port@mode string.

    Parses the mode identifier from the standard port@mode naming convention
    used in multimode S-matrices.

    Args:
        pm: Port-mode string in the format "port@mode".

    Returns:
        The mode identifier (part after @).

    Example:
        ```python
        mode = get_mode("waveguide1@TE")
        # Result: "TE"

        mode = get_mode("input@TM")
        # Result: "TM"
        ```
    """
    return pm.split("@")[1]

get_modes

get_modes(S: STypeMM) -> tuple[Mode, ...]

Extract the optical modes from a multimode S-matrix.

Returns all unique optical modes (e.g., "TE", "TM") present in the multimode S-matrix port names.

Parameters:

Name Type Description Default
S STypeMM

Multimode S-matrix in any format with port@mode naming convention.

required

Returns:

Type Description
tuple[Mode, ...]

Tuple of unique mode names found in the S-matrix.

Example
# Multimode S-matrix with TE and TM modes
s_mm = {
    ("in@TE", "out@TE"): 0.9,
    ("in@TM", "out@TM"): 0.8,
    ("in@TE", "out@TM"): 0.1,
}
modes = get_modes(s_mm)
# Result: ("TE", "TM")
Source code in src/sax/s.py
@validate_call
def get_modes(S: sax.STypeMM) -> tuple[sax.Mode, ...]:
    """Extract the optical modes from a multimode S-matrix.

    Returns all unique optical modes (e.g., "TE", "TM") present in the
    multimode S-matrix port names.

    Args:
        S: Multimode S-matrix in any format with port@mode naming convention.

    Returns:
        Tuple of unique mode names found in the S-matrix.

    Example:
        ```python
        # Multimode S-matrix with TE and TM modes
        s_mm = {
            ("in@TE", "out@TE"): 0.9,
            ("in@TM", "out@TM"): 0.8,
            ("in@TE", "out@TM"): 0.1,
        }
        modes = get_modes(s_mm)
        # Result: ("TE", "TM")
        ```
    """
    return tuple(dict.fromkeys(get_mode(pm) for pm in get_ports(S)))
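Every port appears once per mode, so taking the mode of each port name produces repeated entries. One way to get the unique modes while keeping their first-seen order is `dict.fromkeys`. The sketch below assumes the input is the sorted tuple of port@mode names (as `get_ports` returns them), and the helper name `unique_modes` is hypothetical.

```python
def unique_modes(port_modes: tuple[str, ...]) -> tuple[str, ...]:
    """Return the distinct modes in first-seen order (hypothetical helper)."""
    modes = (pm.split("@")[1] for pm in port_modes)
    # dict keys keep insertion order and silently drop duplicates
    return tuple(dict.fromkeys(modes))


ports = ("in@TE", "in@TM", "out@TE", "out@TM")
modes = unique_modes(ports)
```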

get_port_combinations

get_port_combinations(S: Model | SType) -> tuple[tuple[str, str], ...]

Extract all port pair combinations from an S-matrix.

Returns all (input_port, output_port) combinations present in the S-matrix. This is useful for understanding the connectivity structure of the device.

Parameters:

S (Model | SType, required): S-matrix in any format (SDict, SCoo, or SDense).

Returns:

tuple[tuple[str, str], ...]: Tuple of (port1, port2) combinations representing S-parameter entries.

Raises:

TypeError: If input is a model function (not evaluated).
ValueError: If the port combinations cannot be extracted from S.

Example
# S-matrix with cross-coupling
s_matrix = {
    ("in1", "out1"): 0.9,
    ("in1", "out2"): 0.1,
    ("in2", "out1"): 0.1,
    ("in2", "out2"): 0.9,
}
combinations = get_port_combinations(s_matrix)
# Result:
# (("in1", "out1"), ("in1", "out2"), ("in2", "out1"), ("in2", "out2"))
Source code in src/sax/s.py
def get_port_combinations(S: sax.Model | sax.SType) -> tuple[tuple[str, str], ...]:
    """Extract all port pair combinations from an S-matrix.

    Returns all (input_port, output_port) combinations present in the S-matrix.
    This is useful for understanding the connectivity structure of the device.

    Args:
        S: S-matrix in any format (SDict, SCoo, or SDense).

    Returns:
        Tuple of (port1, port2) combinations representing S-parameter entries.

    Raises:
        TypeError: If input is a model function (not evaluated).
        ValueError: If the port combinations cannot be extracted from S.

    Example:
        ```python
        # S-matrix with cross-coupling
        s_matrix = {
            ("in1", "out1"): 0.9,
            ("in1", "out2"): 0.1,
            ("in2", "out1"): 0.1,
            ("in2", "out2"): 0.9,
        }
        combinations = get_port_combinations(s_matrix)
        # Result:
        # (("in1", "out1"), ("in1", "out2"), ("in2", "out1"), ("in2", "out2"))
        ```
    """
    if callable(S):
        msg = (
            "Getting the port combinations of a model is no longer supported. "
            "Please evaluate the model first: use get_port_combinations(model()) "
            f"instead of get_port_combinations(model). Got: {S}"
        )
        raise TypeError(msg)
    if isinstance(sdict := S, dict):
        return tuple(sdict.keys())
    if len(scoo := cast(sax.SCoo, S)) == 4:
        Si, Sj, _, pm = scoo
        rpm = {int(i): str(p) for p, i in pm.items()}
        return tuple(
            natsorted((rpm[int(i)], rpm[int(j)]) for i, j in zip(Si, Sj, strict=False))
        )
    if len(sdense := cast(sax.SDense, S)) == 2:
        _, pm = sdense
        return tuple(natsorted((p1, p2) for p1 in pm for p2 in pm))
    msg = "Could not extract ports for given S"
    raise ValueError(msg)
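For the sparse SCoo format, the port pairs are recovered by inverting the port map (name to index) and looking up each (row, column) index pair. The sketch below mirrors that branch with plain Python lists instead of JAX arrays, and with the standard-library `sorted` in place of `natsorted`, so names with multi-digit numbers (such as `o10`) could sort differently. The helper name `coo_port_pairs` is hypothetical.

```python
def coo_port_pairs(
    Si: list[int], Sj: list[int], port_map: dict[str, int]
) -> tuple[tuple[str, str], ...]:
    """Map sparse (i, j) index pairs back to (port, port) name pairs."""
    # invert the port map: index -> port name
    index_to_port = {i: p for p, i in port_map.items()}
    named = ((index_to_port[i], index_to_port[j]) for i, j in zip(Si, Sj))
    return tuple(sorted(named))


pm = {"in1": 0, "in2": 1, "out1": 2, "out2": 3}
result = coo_port_pairs([0, 0, 1, 1], [2, 3, 2, 3], pm)
```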

get_port_naming_strategy

get_port_naming_strategy() -> PortNamingStrategy

Get the current port naming strategy.

Returns the currently active port naming strategy that determines how ports are named in SAX model functions.

Returns:

PortNamingStrategy: Current port naming strategy ("optical" or "inout").

Examples:

Check current strategy:

import sax.models

strategy = sax.models.get_port_naming_strategy()
print(f"Current strategy: {strategy}")

Use in conditional logic:

if sax.models.get_port_naming_strategy() == "optical":
    print("Using optical port naming (o1, o2, ...)")
else:
    print("Using input/output port naming (in0, out0, ...)")
Source code in src/sax/ports.py
def get_port_naming_strategy() -> PortNamingStrategy:
    """Get the current port naming strategy.

    Returns the currently active port naming strategy that determines how
    ports are named in SAX model functions.

    Returns:
        Current port naming strategy ("optical" or "inout").

    Examples:
        Check current strategy:

        ```python
        import sax.models

        strategy = sax.models.get_port_naming_strategy()
        print(f"Current strategy: {strategy}")
        ```

        Use in conditional logic:

        ```python
        if sax.models.get_port_naming_strategy() == "optical":
            print("Using optical port naming (o1, o2, ...)")
        else:
            print("Using input/output port naming (in0, out0, ...)")
        ```
    """
    return PORT_NAMING_STRATEGY

get_ports

get_ports(S: SType) -> tuple[Port, ...] | tuple[PortMode, ...]

Extract port names from an S-matrix.

Returns the port names present in the S-matrix in natural sorted order. For multimode S-matrices, returns port@mode combinations.

Parameters:

S (SType, required): S-matrix in any format (SDict, SCoo, or SDense).

Returns:

tuple[Port, ...] | tuple[PortMode, ...]: Tuple of port names (for single-mode) or port@mode strings (for multimode).

Raises:

TypeError: If input is a model function (not evaluated) or invalid type.

Example
# Single-mode matrix
s_matrix = {("in", "out"): 0.9, ("out", "in"): 0.9}
ports = get_ports(s_matrix)
# Result: ("in", "out")

# Multimode matrix
s_mm = {("in@TE", "out@TE"): 0.9, ("in@TM", "out@TM"): 0.8}
ports_mm = get_ports(s_mm)
# Result: ("in@TE", "in@TM", "out@TE", "out@TM")
Source code in src/sax/s.py
def get_ports(S: sax.SType) -> tuple[sax.Port, ...] | tuple[sax.PortMode, ...]:
    """Extract port names from an S-matrix.

    Returns the port names present in the S-matrix in natural sorted order.
    For multimode S-matrices, returns port@mode combinations.

    Args:
        S: S-matrix in any format (SDict, SCoo, or SDense).

    Returns:
        Tuple of port names (for single-mode) or port@mode strings (for multimode).

    Raises:
        TypeError: If input is a model function (not evaluated) or invalid type.

    Example:
        ```python
        # Single-mode matrix
        s_matrix = {("in", "out"): 0.9, ("out", "in"): 0.9}
        ports = get_ports(s_matrix)
        # Result: ("in", "out")

        # Multimode matrix
        s_mm = {("in@TE", "out@TE"): 0.9, ("in@TM", "out@TM"): 0.8}
        ports_mm = get_ports(s_mm)
        # Result: ("in@TE", "in@TM", "out@TE", "out@TM")
        ```
    """
    if callable(S):
        msg = (
            "Getting the ports of a model is no longer supported. "
            "Please evaluate the model first: use get_ports(model()) instead of "
            f"get_ports(model). Got: {S}"
        )
        raise TypeError(msg)
    if isinstance(sdict := S, dict):
        ports_set = {p1 for p1, _ in sdict} | {p2 for _, p2 in sdict}
        return tuple(natsorted(ports_set))

    if isinstance(S, tuple):
        *_, pm = cast(sax.SCoo | sax.SDense, S)
        return tuple(natsorted(pm.keys()))

    msg = f"Expected an SType. Got: {S!r} [{type(S)}]"
    raise TypeError(msg)
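For the SDict format, the ports are the union of every key's source and target, so a port that appears only as a target is still reported. The sketch below reproduces that set union with plain Python, using `sorted` instead of `natsorted`. The helper name `sdict_ports` is hypothetical.

```python
def sdict_ports(sdict: dict[tuple[str, str], complex]) -> tuple[str, ...]:
    """Collect every port that appears as a source or a target key."""
    ports = {p1 for p1, _ in sdict} | {p2 for _, p2 in sdict}
    return tuple(sorted(ports))


one_way = sdict_ports({("in", "out"): 0.9})
multimode = sdict_ports({("in@TE", "out@TE"): 0.9, ("in@TM", "out@TM"): 0.8})
```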

get_required_circuit_models

get_required_circuit_models(
    netlist: AnyNetlist, models: dict[str, Model] | None = None
) -> list[str]

Determine which component models are required for a given netlist.

Analyzes a netlist to identify all component types that need model functions. This is useful for validating that all required models are available before circuit construction.

Parameters:

netlist (AnyNetlist, required): Circuit netlist to analyze for component dependencies.
models (dict[str, Model] | None, default None): Optional dictionary of available models. Used to filter out models that are already provided.

Returns:

list[str]: List of component names that require model functions.

Example
netlist = {
    "instances": {
        "wg1": {"component": "waveguide"},
        "dc1": {"component": "directional_coupler"},
    },
    "ports": {"in": "wg1,in", "out": "dc1,out"},
}
required = get_required_circuit_models(netlist)
# Result: ["waveguide", "directional_coupler"]

# With some models already available
models = {"waveguide": my_waveguide_model}
required = get_required_circuit_models(netlist, models)
# Result: ["directional_coupler"]
Source code in src/sax/circuits.py
def get_required_circuit_models(
    netlist: sax.AnyNetlist,
    models: dict[str, sax.Model] | None = None,
) -> list[str]:
    """Determine which component models are required for a given netlist.

    Analyzes a netlist to identify all component types that need model functions.
    This is useful for validating that all required models are available before
    circuit construction.

    Args:
        netlist: Circuit netlist to analyze for component dependencies.
        models: Optional dictionary of available models. Used to filter out
            models that are already provided.

    Returns:
        List of component names that require model functions.

    Example:
        ```python
        netlist = {
            "instances": {
                "wg1": {"component": "waveguide"},
                "dc1": {"component": "directional_coupler"},
            },
            "ports": {"in": "wg1,in", "out": "dc1,out"},
        }
        required = get_required_circuit_models(netlist)
        # Result: ["waveguide", "directional_coupler"]

        # With some models already available
        models = {"waveguide": my_waveguide_model}
        required = get_required_circuit_models(netlist, models)
        # Result: ["directional_coupler"]
        ```
    """
    instance_models = _extract_instance_models(netlist)
    recnet = into_recnet(
        netlist,
    )
    recnet = remove_unused_instances(recnet)
    dependency_dag = _create_dag(recnet, models, validate=True)
    _, required, _ = _find_missing_models(
        models,
        dependency_dag,
        extra_models=instance_models,
    )
    return required

get_settings

get_settings(model: Model | ModelFactory) -> Settings

Extract default parameter settings from a SAX model function.

Inspects a model function's signature to extract default parameter values. This is useful for understanding what parameters a model accepts and their default values.

Parameters:

model (Model | ModelFactory, required): SAX model function or model factory to inspect.

Returns:

Settings: Dictionary of parameter names and their default values.

Example
def my_model(wl=1.55, length=10.0, neff=2.4):
    return some_s_matrix


settings = get_settings(my_model)
# Result: {"wl": 1.55, "length": 10.0, "neff": 2.4}
Source code in src/sax/utils.py
def get_settings(model: sax.Model | sax.ModelFactory) -> sax.Settings:
    """Extract default parameter settings from a SAX model function.

    Inspects a model function's signature to extract default parameter values.
    This is useful for understanding what parameters a model accepts and their
    default values.

    Args:
        model: SAX model function or model factory to inspect.

    Returns:
        Dictionary of parameter names and their default values.

    Example:
        ```python
        def my_model(wl=1.55, length=10.0, neff=2.4):
            return some_s_matrix


        settings = get_settings(my_model)
        # Result: {"wl": 1.55, "length": 10.0, "neff": 2.4}
        ```
    """
    signature = inspect.signature(model)
    settings: sax.Settings = {
        k: (v.default if not isinstance(v, dict) else v)
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    return cast(sax.Settings, settings)
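Because the settings come from `inspect.signature`, defaults overridden with `functools.partial` are picked up as well. The sketch below re-implements the lookup as a hypothetical `defaults_of` helper so that it runs with the standard library only. The placeholder model simply returns a constant SDict.

```python
import functools
import inspect


def defaults_of(func) -> dict:
    """Parameters that have a default value (same idea as get_settings)."""
    return {
        name: param.default
        for name, param in inspect.signature(func).parameters.items()
        if param.default is not inspect.Parameter.empty
    }


def my_model(wl=1.55, length=10.0, neff=2.4):
    return {("in", "out"): 1.0}  # placeholder S-matrix


longer = functools.partial(my_model, length=20.0)
settings = defaults_of(longer)
```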

grouped_interp

grouped_interp(wl: FloatArray, wls: FloatArray, phis: FloatArray) -> FloatArray

Perform grouped phase interpolation for optical phase data.

Grouped interpolation is intended for phase data in which every sample is stored as a pair of closely spaced wavelength points, so that each pair indicates how the phase varies around that point. This is common in optical simulations where phase unwrapping is needed.

Parameters:

wl (FloatArray, required): Wavelength points where interpolation is desired.
wls (FloatArray, required): Reference wavelength points (1D array).
phis (FloatArray, required): Phase values at reference wavelengths (1D array).

Returns:

FloatArray: Interpolated phase values at the requested wavelengths.

Warning

This interpolation is only accurate in the range [wls[0], wls[-2]) (wls[-2] not included). Any extrapolation outside these bounds can yield unexpected results!

Raises:

ValueError: If wls or phis are not 1D arrays or have mismatched shapes.

Example
import jax.numpy as jnp

# Reference phase data with grouped points
wls_ref = jnp.array([1.50, 1.501, 1.55, 1.551, 1.60, 1.601])
phis_ref = jnp.array([0.1, 0.11, 0.5, 0.51, 1.0, 1.01])

# Interpolate at new wavelengths
wl_new = jnp.linspace(1.51, 1.59, 10)
phis_interp = grouped_interp(wl_new, wls_ref, phis_ref)
Source code in src/sax/utils.py
def grouped_interp(
    wl: sax.FloatArray, wls: sax.FloatArray, phis: sax.FloatArray
) -> sax.FloatArray:
    """Perform grouped phase interpolation for optical phase data.

    Grouped interpolation is intended for phase data in which every sample
    is stored as a pair of closely spaced wavelength points, so that each
    pair indicates how the phase varies around that point. This is common in
    optical simulations where phase unwrapping is needed.

    Args:
        wl: Wavelength points where interpolation is desired.
        wls: Reference wavelength points (1D array).
        phis: Phase values at reference wavelengths (1D array).

    Returns:
        Interpolated phase values at the requested wavelengths.

    Warning:
        This interpolation is only accurate in the range [wls[0], wls[-2])
        (wls[-2] not included). Any extrapolation outside these bounds can
        yield unexpected results!

    Raises:
        ValueError: If wls or phis are not 1D arrays or have mismatched shapes.

    Example:
        ```python
        import jax.numpy as jnp

        # Reference phase data with grouped points
        wls_ref = jnp.array([1.50, 1.501, 1.55, 1.551, 1.60, 1.601])
        phis_ref = jnp.array([0.1, 0.11, 0.5, 0.51, 1.0, 1.01])

        # Interpolate at new wavelengths
        wl_new = jnp.linspace(1.51, 1.59, 10)
        phis_interp = grouped_interp(wl_new, wls_ref, phis_ref)
        ```
    """
    wl = jnp.asarray(wl)
    wls = jnp.asarray(wls)
    phis = jnp.asarray(phis) % (2 * jnp.pi)
    phis = jnp.where(phis > jnp.pi, phis - 2 * jnp.pi, phis)
    if wls.ndim != 1:
        msg = "grouped_interp: wls should be a 1D array"
        raise ValueError(msg)
    if phis.ndim != 1:
        msg = "grouped_interp: phis should be a 1D array"
        raise ValueError(msg)
    if wls.shape != phis.shape:
        msg = "grouped_interp: wls and phis shape does not match"
        raise ValueError(msg)
    return _grouped_interp(wl.reshape(-1), wls, phis).reshape(*wl.shape)
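Before interpolating, `grouped_interp` wraps every phase into the interval (-pi, pi] with a modulo followed by a conditional shift of 2*pi. The sketch below applies the same two steps to a single float using only `math`. The helper name `wrap_phase` is hypothetical.

```python
import math


def wrap_phase(phi: float) -> float:
    """Wrap a phase into (-pi, pi], mirroring the preprocessing above."""
    phi = phi % (2 * math.pi)  # Python's % returns a value in [0, 2*pi)
    return phi - 2 * math.pi if phi > math.pi else phi


wrapped = wrap_phase(1.5 * math.pi)
```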

hash_dict

hash_dict(dic: dict) -> int

Compute a hash value for a dictionary.

Creates a deterministic hash of a dictionary that can contain NumPy arrays and nested structures. Useful for caching and change detection.

Parameters:

dic (dict, required): Dictionary to hash.

Returns:

int: Integer value of the MD5 digest of the key-sorted serialization of the dictionary.

Example
settings = {"wl": 1.55, "length": 10.0}
hash_val = hash_dict(settings)
# Same settings will always produce the same hash
Source code in src/sax/utils.py
def hash_dict(dic: dict) -> int:
    """Compute a hash value for a dictionary.

    Creates a deterministic hash of a dictionary that can contain NumPy arrays
    and nested structures. Useful for caching and change detection.

    Args:
        dic: Dictionary to hash.

    Returns:
        Integer value of the MD5 digest of the key-sorted serialization of
            the dictionary.

    Example:
        ```python
        settings = {"wl": 1.55, "length": 10.0}
        hash_val = hash_dict(settings)
        # Same settings will always produce the same hash
        ```
    """
    return int(
        md5(  # noqa: S324
            orjson.dumps(
                _numpyfy(dic), option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SORT_KEYS
            )
        ).hexdigest(),
        16,
    )
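The hash is order-independent because the dictionary is serialized with sorted keys before hashing. The sketch below shows the same idea with the standard library, using `json` in place of `orjson`. Unlike `hash_dict`, it assumes JSON-serializable values only and cannot handle NumPy arrays. The helper name `stable_hash` is hypothetical.

```python
import hashlib
import json


def stable_hash(dic: dict) -> int:
    """Order-independent dictionary hash (stdlib stand-in for hash_dict)."""
    # sort_keys makes the serialization independent of insertion order
    payload = json.dumps(dic, sort_keys=True).encode()
    return int(hashlib.md5(payload).hexdigest(), 16)


h1 = stable_hash({"wl": 1.55, "length": 10.0})
h2 = stable_hash({"length": 10.0, "wl": 1.55})
```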

huber_loss

huber_loss(x: ComplexArray, y: ComplexArray, delta: float = 0.5) -> float

Compute Huber loss between two complex arrays.

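The Huber loss combines the best properties of MSE and MAE losses. This implementation uses the smooth pseudo-Huber form, delta**2 * (sqrt(1 + (|x - y| / delta)**2) - 1) averaged over all elements: it is approximately quadratic for |x - y| much smaller than delta and approximately linear for |x - y| much larger than delta, making it robust to outliers.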

Parameters:

x (ComplexArray, required): First complex array for comparison.
y (ComplexArray, required): Second complex array for comparison.
delta (float, default 0.5): Threshold parameter controlling the transition between quadratic and linear regions.

Returns:

float: Huber loss as a float value.

Example
import jax.numpy as jnp

x = jnp.array([1 + 2j, 3 + 4j])
y = jnp.array([1 + 1j, 3 + 3j])
loss = huber_loss(x, y, delta=1.0)
Source code in src/sax/loss.py
def huber_loss(x: ComplexArray, y: ComplexArray, delta: float = 0.5) -> float:
    """Compute Huber loss between two complex arrays.

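    The Huber loss combines the best properties of MSE and MAE losses. This
    implementation uses the smooth pseudo-Huber form
    delta**2 * (sqrt(1 + (|x - y| / delta)**2) - 1), averaged over all
    elements: it is approximately quadratic for |x - y| much smaller than
    delta and approximately linear for |x - y| much larger than delta,
    making it robust to outliers.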

    Args:
        x: First complex array for comparison.
        y: Second complex array for comparison.
        delta: Threshold parameter controlling the transition between quadratic
            and linear regions. Defaults to 0.5.

    Returns:
        Huber loss as a float value.

    Example:
        ```python
        import jax.numpy as jnp

        x = jnp.array([1 + 2j, 3 + 4j])
        y = jnp.array([1 + 1j, 3 + 3j])
        loss = huber_loss(x, y, delta=1.0)
        ```
    """
    loss = ((delta**2) * ((1.0 + (abs(x - y) / delta) ** 2) ** 0.5 - 1.0)).mean()
    return cast(float, loss)
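The pseudo-Huber expression above can be checked with plain Python on the docstring's example data. Both residuals have magnitude 1, so with `delta=1.0` the loss is `sqrt(2) - 1`. The sketch also checks the two limits: about `r**2 / 2` for small residuals and about `delta * |r|` for large ones. The helper name `pseudo_huber` is hypothetical and works on lists instead of JAX arrays.

```python
import math


def pseudo_huber(x: list[complex], y: list[complex], delta: float = 0.5) -> float:
    """Mean pseudo-Huber loss of |x - y|, the expression used by huber_loss."""
    terms = [
        delta**2 * (math.sqrt(1.0 + (abs(a - b) / delta) ** 2) - 1.0)
        for a, b in zip(x, y)
    ]
    return sum(terms) / len(terms)


x = [1 + 2j, 3 + 4j]
y = [1 + 1j, 3 + 3j]
loss = pseudo_huber(x, y, delta=1.0)  # both |x - y| == 1
```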

l2_reg

l2_reg(weights: dict[str, ComplexArray]) -> float

Compute L2 regularization loss for model weights.

L2 regularization adds a penalty proportional to the squared magnitudes of the model parameters, helping to prevent overfitting. The returned value is the mean of |w|**2 over all elements of the selected arrays.

Parameters:

weights (dict[str, ComplexArray], required): Dictionary of weight arrays where keys starting with 'w' or 'b' are considered for regularization (typically weights and biases).

Returns:

float: L2 regularization loss normalized by total number of elements.

Example
import jax.numpy as jnp

weights = {
    "w1": jnp.array([1 + 1j, 2 + 2j]),
    "b1": jnp.array([0.1 + 0.2j]),
    "other": jnp.array([5 + 5j]),  # ignored (doesn't start with w/b)
}
reg_loss = l2_reg(weights)
Source code in src/sax/loss.py
def l2_reg(weights: dict[str, ComplexArray]) -> float:
    """Compute L2 regularization loss for model weights.

    L2 regularization adds a penalty proportional to the squared magnitudes
    of the model parameters, helping to prevent overfitting. The returned
    value is the mean of |w|**2 over all elements of the selected arrays.

    Args:
        weights: Dictionary of weight arrays where keys starting with 'w' or 'b'
            are considered for regularization (typically weights and biases).

    Returns:
        L2 regularization loss normalized by total number of elements.

    Example:
        ```python
        import jax.numpy as jnp

        weights = {
            "w1": jnp.array([1 + 1j, 2 + 2j]),
            "b1": jnp.array([0.1 + 0.2j]),
            "other": jnp.array([5 + 5j]),  # ignored (doesn't start with w/b)
        }
        reg_loss = l2_reg(weights)
        ```
    """
    numel = 0
    loss = 0.0
    for w in (v for k, v in weights.items() if k[0] in ("w", "b")):
        numel = numel + w.size
        loss = loss + (jnp.abs(w) ** 2).sum()
    return cast(float, loss / numel)
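On the docstring's example, only `w1` and `b1` are selected, so the result is (2 + 8 + 0.05) / 3 = 3.35. The sketch below reproduces the computation with plain lists of complex numbers. It follows the code above in dividing by the element count, so, as in the original, it raises `ZeroDivisionError` when no key starts with `w` or `b`. The helper name `l2_penalty` is hypothetical.

```python
def l2_penalty(weights: dict[str, list[complex]]) -> float:
    """Mean |w|**2 over arrays whose key starts with 'w' or 'b'."""
    selected = [v for k, v in weights.items() if k[0] in ("w", "b")]
    numel = sum(len(v) for v in selected)
    total = sum(abs(z) ** 2 for v in selected for z in v)
    return total / numel  # ZeroDivisionError if nothing was selected


weights = {
    "w1": [1 + 1j, 2 + 2j],
    "b1": [0.1 + 0.2j],
    "other": [5 + 5j],  # ignored: key does not start with w or b
}
reg = l2_penalty(weights)
```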

load_netlist

load_netlist(content_or_filename: str | Path | IOLike) -> Netlist

Load a SAX netlist from YAML content or file.

Parses YAML content to create a SAX netlist dictionary. The YAML should contain instances, connections, and ports sections.

Parameters:

content_or_filename (str | Path | IOLike, required): YAML content as string, file path, or file-like object.

Returns:

Netlist: Parsed netlist dictionary.

Example
# Load from file
netlist = load_netlist("circuit.yml")

# Load from YAML string
yaml_content = '''
instances:
  wg1:
    component: waveguide
ports:
  in: wg1,in
  out: wg1,out
'''
netlist = load_netlist(yaml_content)
Source code in src/sax/utils.py
def load_netlist(content_or_filename: str | Path | sax.IOLike) -> sax.Netlist:
    """Load a SAX netlist from YAML content or file.

    Parses YAML content to create a SAX netlist dictionary. The YAML should
    contain instances, connections, and ports sections.

    Args:
        content_or_filename: YAML content as string, file path, or file-like object.

    Returns:
        Parsed netlist dictionary.

    Example:
        ```python
        # Load from file
        netlist = load_netlist("circuit.yml")

        # Load from YAML string
        yaml_content = '''
        instances:
          wg1:
            component: waveguide
        ports:
          in: wg1,in
          out: wg1,out
        '''
        netlist = load_netlist(yaml_content)
        ```
    """
    return yaml.safe_load(read(content_or_filename))

load_recursive_netlist

load_recursive_netlist(
    top_level_path: str | Path, ext: str = ".pic.yml"
) -> RecursiveNetlist

Load a SAX recursive netlist from a directory of YAML files.

Loads the top-level netlist together with every file ending in the specified extension found in its directory (searched recursively) to create a recursive netlist. Each file becomes a component in the recursive netlist, named after its filename with the extension removed.

Parameters:

top_level_path (str | Path, required): Path to the top-level netlist file. Its parent directory is searched for the remaining netlists.
ext (str, default ".pic.yml"): File extension to search for.

Returns:

RecursiveNetlist: Recursive netlist dictionary mapping component names to their netlists.

Example
# Load all .pic.yml files in directory
recnet = load_recursive_netlist("circuits/main.pic.yml")
# Result: {"main": {...}, "component1": {...}, "component2": {...}}
Source code in src/sax/utils.py
def load_recursive_netlist(
    top_level_path: str | Path,
    ext: str = ".pic.yml",
) -> sax.RecursiveNetlist:
    """Load a SAX recursive netlist from a directory of YAML files.

    Loads the top-level netlist together with every file ending in the
    specified extension found in its directory (searched recursively) to
    create a recursive netlist. Each file becomes a component in the
    recursive netlist, named after its filename with the extension removed.

    Args:
        top_level_path: Path to the top-level netlist file. Its parent
            directory is searched for the remaining netlists.
        ext: File extension to search for. Defaults to ".pic.yml".

    Returns:
        Recursive netlist dictionary mapping component names to their netlists.

    Example:
        ```python
        # Load all .pic.yml files in directory
        recnet = load_recursive_netlist("circuits/main.pic.yml")
        # Result: {"main": {...}, "component1": {...}, "component2": {...}}
        ```
    """
    top_level_path = Path(top_level_path)
    folder_path = top_level_path.parent

    def _net_name(path: Path) -> sax.Name:
        return clean_string(path.name.removesuffix(ext))

    recnet = {_net_name(top_level_path): load_netlist(top_level_path)}

    for path in folder_path.rglob(f"*{ext}"):
        recnet[_net_name(path)] = load_netlist(path)

    return recnet
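The directory search depends on the glob pattern. `Path.rglob(".pic.yml")` only matches files literally named `.pic.yml`, so a leading `*` is needed to match `main.pic.yml`. The sketch below collects the component names that such a search would produce, using a temporary directory. It skips YAML parsing and the `clean_string` step, and the helper name `netlist_names` is hypothetical.

```python
import tempfile
from pathlib import Path


def netlist_names(top_level_path: Path, ext: str = ".pic.yml") -> list[str]:
    """Names of files ending in `ext` next to or below the top-level file."""
    folder = top_level_path.parent
    # "*" is required: rglob(".pic.yml") matches only a file named ".pic.yml"
    return sorted(p.name.removesuffix(ext) for p in folder.rglob(f"*{ext}"))


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    for rel in ("main.pic.yml", "sub/mzi.pic.yml", "notes.txt"):
        (root / rel).write_text("instances: {}\n")
    names = netlist_names(root / "main.pic.yml")
```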

maybe

maybe(
    func: Callable[..., T], /, exc: type[Exception] = Exception
) -> Callable[..., T | None]

Create a safe version of a function that returns None on exceptions.

Wraps a function to catch specified exceptions and return None instead of raising them. This is useful for optional operations or when you want to gracefully handle failures.

Parameters:

func (Callable[..., T], required): Function to wrap for safe execution.
exc (type[Exception], default Exception): Exception type(s) to catch. Defaults to Exception (catches all).

Returns:

Callable[..., T | None]: Wrapped function that returns None when the specified exception occurs.

Example
# Safe division that returns None for division by zero
safe_divide = maybe(lambda x, y: x / y, ZeroDivisionError)
result = safe_divide(10, 0)  # Returns None instead of raising

# Safe file reading
safe_read = maybe(lambda f: open(f).read(), FileNotFoundError)
content = safe_read("nonexistent.txt")  # Returns None
Source code in src/sax/utils.py
def maybe(
    func: Callable[..., T], /, exc: type[Exception] = Exception
) -> Callable[..., T | None]:
    """Create a safe version of a function that returns None on exceptions.

    Wraps a function to catch specified exceptions and return None instead of
    raising them. This is useful for optional operations or when you want to
    gracefully handle failures.

    Args:
        func: Function to wrap for safe execution.
        exc: Exception type(s) to catch. Defaults to Exception (catches all).

    Returns:
        Wrapped function that returns None when the specified exception occurs.

    Example:
        ```python
        # Safe division that returns None for division by zero
        safe_divide = maybe(lambda x, y: x / y, ZeroDivisionError)
        result = safe_divide(10, 0)  # Returns None instead of raising

        # Safe file reading
        safe_read = maybe(lambda f: open(f).read(), FileNotFoundError)
        content = safe_read("nonexistent.txt")  # Returns None
        ```
    """

    @wraps(func)
    def new_func(*args: Any, **kwargs: Any) -> T | None:  # noqa: ANN401
        try:
            return func(*args, **kwargs)
        except exc:
            return None

    return new_func
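The wrapper uses a plain `except exc:` clause, and `except` also accepts a tuple of exception types, so several failure modes can be mapped to `None` at once. The self-contained copy below is renamed `maybe_none` so it is not mistaken for SAX's import path. Its annotation for `exc` is widened to allow a tuple, and `parse_port_index` is a hypothetical function used only for illustration.

```python
from __future__ import annotations

from functools import wraps
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def maybe_none(
    func: Callable[..., T],
    /,
    exc: type[Exception] | tuple[type[Exception], ...] = Exception,
) -> Callable[..., T | None]:
    """Return a wrapper that turns the given exception(s) into None."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T | None:
        try:
            return func(*args, **kwargs)
        except exc:
            return None

    return wrapper


def parse_port_index(name: str) -> int:
    """Hypothetical helper: "o3" -> 3."""
    return int(name.removeprefix("o"))


safe_index = maybe_none(parse_port_index, (ValueError, AttributeError))
```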

merge_dicts

merge_dicts(*dicts: dict) -> dict

Merge multiple dictionaries with support for nested merging.

Recursively merges multiple dictionaries, with later dictionaries taking precedence over earlier ones. Nested dictionaries are merged recursively rather than replaced entirely.

Parameters:

*dicts (dict): Variable number of dictionaries to merge.

Returns:

dict: Merged dictionary with nested structures preserved.

Example
dict1 = {"a": 1, "nested": {"x": 10}}
dict2 = {"b": 2, "nested": {"y": 20}}
merged = merge_dicts(dict1, dict2)
# Result: {"a": 1, "b": 2, "nested": {"x": 10, "y": 20}}
Source code in src/sax/utils.py
def merge_dicts(*dicts: dict) -> dict:
    """Merge multiple dictionaries with support for nested merging.

    Recursively merges multiple dictionaries, with later dictionaries
    taking precedence over earlier ones. Nested dictionaries are merged
    recursively rather than replaced entirely.

    Args:
        *dicts: Variable number of dictionaries to merge.

    Returns:
        Merged dictionary with nested structures preserved.

    Example:
        ```python
        dict1 = {"a": 1, "nested": {"x": 10}}
        dict2 = {"b": 2, "nested": {"y": 20}}
        merged = merge_dicts(dict1, dict2)
        # Result: {"a": 1, "b": 2, "nested": {"x": 10, "y": 20}}
        ```
    """
    num_dicts = len(dicts)

    if num_dicts < 1:
        return {}

    if num_dicts == 1:
        return dict(_generate_merged_dict(dicts[0], {}))

    if num_dicts == 2:
        return dict(_generate_merged_dict(dicts[0], dicts[1]))

    return merge_dicts(dicts[0], merge_dicts(*dicts[1:]))
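The private helper `_generate_merged_dict` is not shown here, so the sketch below is an independent re-implementation of the described behavior, not SAX's code. Values from later dictionaries win, and two dict values under the same key are merged recursively. It assumes that a dict value replaces a non-dict value and vice versa, which may differ from the real helper in edge cases. The name `merge_nested` is hypothetical.

```python
def merge_nested(*dicts: dict) -> dict:
    """Merge dicts left to right; nested dicts merge, other values replace."""
    result: dict = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict):
                existing = result.get(key)
                base = existing if isinstance(existing, dict) else {}
                # recurse into a fresh dict so the inputs are never mutated
                result[key] = merge_nested(base, value)
            else:
                result[key] = value  # later dictionaries take precedence
    return result


dict1 = {"a": 1, "nested": {"x": 10}}
dict2 = {"b": 2, "nested": {"y": 20}}
merged = merge_nested(dict1, dict2)
```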

mse

mse(x: ComplexArray, y: ComplexArray) -> float

Compute mean squared error between two complex arrays.

Parameters:

x (ComplexArray, required): First complex array for comparison.
y (ComplexArray, required): Second complex array for comparison.

Returns:

float: Mean squared error as a float value.

Example
import jax.numpy as jnp

x = jnp.array([1 + 2j, 3 + 4j])
y = jnp.array([1 + 1j, 3 + 3j])
loss = mse(x, y)
Source code in src/sax/loss.py
def mse(x: ComplexArray, y: ComplexArray) -> float:
    """Compute mean squared error between two complex arrays.

    Args:
        x: First complex array for comparison.
        y: Second complex array for comparison.

    Returns:
        Mean squared error as a float value.

    Example:
        ```python
        import jax.numpy as jnp

        x = jnp.array([1 + 2j, 3 + 4j])
        y = jnp.array([1 + 1j, 3 + 3j])
        loss = mse(x, y)
        ```
    """
    return cast(float, (abs(x - y) ** 2).mean())

netlist

netlist(netlist: AnyNetlist, *, top_level_name: str = 'top_level') -> RecursiveNetlist

Convert a netlist to recursive netlist format.

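Ensures that the input netlist is in recursive netlist format, which is a dictionary mapping component names to their netlists. A flat netlist (one with an "instances" key) is wrapped under top_level_name. If a recursive netlist already contains an entry named top_level_name, that entry is moved to the first position of the returned dictionary.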

Parameters:

netlist (AnyNetlist, required): Input netlist in any format (flat netlist or recursive netlist).
top_level_name (str, default "top_level"): Name to use for the top-level circuit when converting from flat netlist format.

Returns:

RecursiveNetlist: Recursive netlist dictionary with component name keys.

Raises:

TypeError: If the input netlist is not a dictionary.

Example
# Convert flat netlist to recursive format
flat_net = {
    "instances": {"wg1": {"component": "waveguide"}},
    "ports": {"in": "wg1,in", "out": "wg1,out"},
}
rec_net = netlist(flat_net, top_level_name="my_circuit")
# Result: {"my_circuit": flat_net}
Source code in src/sax/netlists.py
def netlist(
    netlist: sax.AnyNetlist,
    *,
    top_level_name: str = "top_level",
) -> sax.RecursiveNetlist:
    """Convert a netlist to recursive netlist format.

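    Ensures that the input netlist is in recursive netlist format, which is a
    dictionary mapping component names to their netlists. A flat netlist
    (one with an "instances" key) is wrapped under `top_level_name`. If a
    recursive netlist already contains an entry named `top_level_name`, that
    entry is moved to the first position of the returned dictionary.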

    Args:
        netlist: Input netlist in any format (flat netlist or recursive netlist).
        top_level_name: Name to use for the top-level circuit when converting
            from flat netlist format. Defaults to "top_level".

    Returns:
        Recursive netlist dictionary with component name keys.

    Raises:
        TypeError: If the input netlist is not a dictionary.

    Example:
        ```python
        # Convert flat netlist to recursive format
        flat_net = {
            "instances": {"wg1": {"component": "waveguide"}},
            "ports": {"in": "wg1,in", "out": "wg1,out"},
        }
        rec_net = netlist(flat_net, top_level_name="my_circuit")
        # Result: {"my_circuit": flat_net}
        ```
    """
    if not isinstance(netlist, dict):
        msg = (
            "Invalid argument for `netlist`. "
            "Expected type: dict | Netlist | RecursiveNetlist. "
            f"Got: {type(netlist)}."
        )
        raise TypeError(msg)

    if "instances" in netlist:
        return {top_level_name: cast(sax.Netlist, netlist)}

    recnet = cast(sax.RecursiveNetlist, netlist)
    top_level = recnet.get(top_level_name, None)
    if top_level is None:
        return recnet
    return {top_level_name: top_level, **recnet}
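The final `{top_level_name: top_level, **recnet}` relies on dictionary ordering. The first insertion fixes the key's position, and the later update from `**recnet` only replaces its value. The sketch below mirrors the branches above with plain dictionaries under a hypothetical name, `to_recursive`.

```python
def to_recursive(net: dict, top_level_name: str = "top_level") -> dict:
    """Simplified mirror of the normalization performed by netlist()."""
    if "instances" in net:  # flat netlist: wrap it under the top-level name
        return {top_level_name: net}
    top = net.get(top_level_name)
    if top is None:
        return net
    # inserting the top level first moves it to the front of the dict
    return {top_level_name: top, **net}


rec = {"sub": {"instances": {}}, "top_level": {"instances": {}}}
reordered = to_recursive(rec)
wrapped = to_recursive({"instances": {}}, top_level_name="my_circuit")
```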

normalization

normalization(x: ComplexArray, axis: int | None = None) -> Normalization

Calculate normalization parameters (mean and std) for an array.

Computes the mean and standard deviation of an array along the specified axis for use in normalization operations.

Parameters:

x (ComplexArray, required): Input array to compute normalization parameters for.
axis (int | None, default None): Axis along which to compute statistics. If None, computes over the entire array.

Returns:

Normalization: Normalization object containing mean and standard deviation.

Example
import jax.numpy as jnp

data = jnp.array([[1, 2, 3], [4, 5, 6]])
norm_params = normalization(data, axis=0)
# Computes mean and std along axis 0
Source code in src/sax/utils.py
def normalization(x: sax.ComplexArray, axis: int | None = None) -> Normalization:
    """Calculate normalization parameters (mean and std) for an array.

    Computes the mean and standard deviation of an array along the specified
    axis for use in normalization operations.

    Args:
        x: Input array to compute normalization parameters for.
        axis: Axis along which to compute statistics. If None, computes over
            the entire array.

    Returns:
        Normalization object containing mean and standard deviation.

    Example:
        ```python
        import jax.numpy as jnp

        data = jnp.array([[1, 2, 3], [4, 5, 6]])
        norm_params = normalization(data, axis=0)
        # Computes mean and std along axis 0
        ```
    """
    if axis is None:
        return Normalization(x.mean(), x.std())
    return Normalization(x.mean(axis), x.std(axis))

normalize

normalize(x: ComplexArray, normalization: Normalization) -> ComplexArray

Normalize an array using provided normalization parameters.

Applies z-score normalization: (x - mean) / std.
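
The z-score formula and its inverse can be sketched in a few lines of stdlib Python. This is not SAX code. `zscore` and `undo_zscore` are hypothetical helpers, and the inverse `z * std + mean` is what the module index's `denormalize` presumably computes, which is an assumption here. A constant input has std = 0, so the division fails (ZeroDivisionError here; inf or nan with JAX arrays).

```python
import math
import statistics


def zscore(values: list[float]) -> tuple[list[float], float, float]:
    """Return (normalized values, mean, std) using population std."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / std for v in values], mean, std


def undo_zscore(z: list[float], mean: float, std: float) -> list[float]:
    """Invert (x - mean) / std, i.e. compute z * std + mean."""
    return [v * std + mean for v in z]


z, mean, std = zscore([1, 2, 3, 4, 5])
restored = undo_zscore(z, mean, std)
# The round trip recovers the original data up to rounding error
assert all(math.isclose(a, b) for a, b in zip(restored, [1, 2, 3, 4, 5]))
```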

Parameters:

  • x (ComplexArray, required): Array to normalize.
  • normalization (Normalization, required): Normalization parameters containing mean and std.

Returns:

  • ComplexArray: Normalized array. It has zero mean and unit standard deviation when the parameters were computed from x itself (along the same axis).

Example
import jax.numpy as jnp

data = jnp.array([1, 2, 3, 4, 5])
norm_params = normalization(data)
normalized = normalize(data, norm_params)
Source code in src/sax/utils.py
def normalize(x: sax.ComplexArray, normalization: Normalization) -> sax.ComplexArray:
    """Normalize an array using provided normalization parameters.

    Applies z-score normalization: (x - mean) / std.

    Args:
        x: Array to normalize.
        normalization: Normalization parameters containing mean and std.

    Returns:
        Normalized array with zero mean and unit standard deviation.

    Example:
        ```python
        data = jnp.array([1, 2, 3, 4, 5])
        norm_params = normalization(data)
        normalized = normalize(data, norm_params)
        ```
    """
    mean, std = normalization
    return (x - mean) / std

read

read(content_or_filename: str | Path | IOLike) -> str

Read content from string, file path, or file-like object.

Flexible content reader for string content, file paths, and file-like objects. A str that contains a newline is returned unchanged; any other str, or a Path, is read as a file; anything else is read through its read() method.
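
The dispatch order matters, because a string is only treated as content when it contains a newline. The block below is a minimal stdlib sketch of the same three-way check, not SAX's code. `read_any` is a hypothetical name, and `Readable` is an assumed stand-in for `IOLike`.

```python
import io
import tempfile
from pathlib import Path
from typing import Protocol, Union


class Readable(Protocol):
    """Anything with a read() method returning str (stand-in for IOLike)."""

    def read(self) -> str: ...


def read_any(source: Union[str, Path, Readable]) -> str:
    # 1) A str containing a newline is the content itself.
    if isinstance(source, str) and "\n" in source:
        return source
    # 2) Any other str, or a Path, is treated as a file path.
    if isinstance(source, (str, Path)):
        return Path(source).read_text()
    # 3) Anything else is assumed to be file-like.
    return source.read()


print(read_any("line1\nline2"))       # returned as-is
print(read_any(io.StringIO("data")))  # read via .read()
with tempfile.TemporaryDirectory() as tmp:
    config = Path(tmp) / "config.yaml"
    config.write_text("wl: 1.55\n")
    print(read_any(config))           # read from disk
```

A one-line string such as "wl: 1.55" has no newline, so it is taken as a file path and fails with FileNotFoundError unless such a file exists. Wrapping it in `io.StringIO` makes it read as content.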

Parameters:

  • content_or_filename (str | Path | IOLike, required): Content as a string (if it contains a newline), a file path, or a file-like object with a read() method.

Returns:

  • str: Content as a string.

Example
# Direct string content (recognized because it contains a newline)
content = read("line1\nline2")

# From file path
content = read("config.yaml")

# From file-like object
from io import StringIO

content = read(StringIO("data"))
Source code in src/sax/utils.py
def read(content_or_filename: str | Path | sax.IOLike) -> str:
    r"""Read content from string, file path, or file-like object.

    Flexible content reader that can handle string content directly, file paths,
    or file-like objects. Automatically detects the input type and reads accordingly.

    Args:
        content_or_filename: Content as string (if contains newlines), file path,
            or file-like object with read() method.

    Returns:
        Content as string.

    Example:
        ```python
        # Direct string content
        content = read("line1\nline2")

        # From file path
        content = read("config.yaml")

        # From file-like object
        from io import StringIO

        content = read(StringIO("data"))
        ```
    """
    if isinstance(content_or_filename, str) and "\n" in content_or_filename:
        return content_or_filename

    if isinstance(content_or_filename, str | Path):
        return Path(content_or_filename).read_text()

    return content_or_filename.read()

reciprocal

reciprocal(sdict: SDict) -> SDict

Make an SDict S-matrix reciprocal by ensuring S[i,j] = S[j,i].

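Reciprocity holds for most passive optical devices (exceptions such as isolators rely on magneto-optic materials): the S-parameter from port i to port j equals the S-parameter from port j to port i. This function enforces it by mirroring each existing entry S[i,j] into S[j,i]. It is meant for S-matrices that define only one direction per port pair. If both S[i,j] and S[j,i] are already present, the mirrored copies overwrite the originals, so the two values are swapped rather than made equal.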
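
The mirroring behaviour can be seen by reproducing the dict merge from the listing below in plain Python. `mirror_sdict` is a hypothetical name, and plain dicts stand in for `SDict`. The second call shows that when both directions are already present, the values end up swapped.

```python
def mirror_sdict(sdict: dict) -> dict:
    """Same merge as in the listing: originals first, then transposed copies."""
    return {
        **{(p1, p2): v for (p1, p2), v in sdict.items()},
        # Later keys win, so each transposed copy overwrites any existing entry
        **{(p2, p1): v for (p1, p2), v in sdict.items()},
    }


one_way = mirror_sdict({("in", "out"): 0.9 + 0.1j})
# {("in", "out"): 0.9+0.1j, ("out", "in"): 0.9+0.1j}

both_ways = mirror_sdict({("a", "b"): 1.0, ("b", "a"): 2.0})
# {("a", "b"): 2.0, ("b", "a"): 1.0}: swapped, not symmetrized
```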

Parameters:

  • sdict (SDict, required): S-matrix in SDict format.

Returns:

  • SDict: Reciprocal S-matrix in SDict format with symmetric port pairs.

Example
# Make a non-reciprocal matrix reciprocal
s_matrix = {("in", "out"): 0.9 + 0.1j}
s_reciprocal = reciprocal(s_matrix)
# Result: {("in", "out"): 0.9+0.1j, ("out", "in"): 0.9+0.1j}
Source code in src/sax/s.py
def reciprocal(sdict: sax.SDict) -> sax.SDict:
    """Make an SDict S-matrix reciprocal by ensuring S[i,j] = S[j,i].

    Reciprocity is a fundamental property of passive optical devices where
    the S-parameter from port i to port j equals the S-parameter from port j
    to port i. This function enforces reciprocity by copying existing values.

    Args:
        sdict: S-matrix in SDict format.

    Returns:
        Reciprocal S-matrix in SDict format with symmetric port pairs.

    Example:
        ```python
        # Make a non-reciprocal matrix reciprocal
        s_matrix = {("in", "out"): 0.9 + 0.1j}
        s_reciprocal = reciprocal(s_matrix)
        # Result: {("in", "out"): 0.9+0.1j, ("out", "in"): 0.9+0.1j}
        ```
    """
    return {
        **{(p1, p2): v for (p1, p2), v in sdict.items()},
        **{(p2, p1): v for (p1, p2), v in sdict.items()},
    }

rename_params

rename_params(model: Model, renamings: dict[str, str]) -> Model

Rename the parameters of a model function.

Creates a new model with renamed parameters while preserving the original functionality and default values.
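
Renaming works in two directions. Incoming keyword arguments are translated back to the old names, and the advertised signature shows the new names with the old defaults. The block below is a minimal stdlib sketch of that idea, not SAX's implementation. `rename_kwargs` is a hypothetical name, and it assumes the model takes only keyword arguments.

```python
import functools
import inspect


def rename_kwargs(func, renamings: dict[str, str]):
    """Hypothetical stand-in for sax.rename_params (keyword-only models)."""
    reverse = {new: old for old, new in renamings.items()}
    if len(reverse) < len(renamings):
        raise ValueError("Multiple old names point to the same new name!")

    @functools.wraps(func)
    def wrapper(**kwargs):
        # Translate new names back to the names the original function expects
        return func(**{reverse.get(k, k): v for k, v in kwargs.items()})

    # Advertise the new names (keeping the original defaults) in the signature
    params = [
        inspect.Parameter(
            renamings.get(name, name),
            inspect.Parameter.KEYWORD_ONLY,
            default=param.default,
        )
        for name, param in inspect.signature(func).parameters.items()
    ]
    wrapper.__signature__ = inspect.Signature(params)
    return wrapper


def original_model(wavelength=1.55, eff_index=2.4):
    return {"wavelength": wavelength, "eff_index": eff_index}


renamed = rename_kwargs(original_model, {"wavelength": "wl", "eff_index": "neff"})
print(inspect.signature(renamed))  # (*, wl=1.55, neff=2.4)
```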

Parameters:

  • model (Model, required): Model function to rename parameters for.
  • renamings (dict[str, str], required): Dictionary mapping old parameter names to new names.

Returns:

  • Model: New model function with renamed parameters.

Raises:

  • ValueError: If multiple old names map to the same new name.

Example
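import jax.numpy as jnp


def original_model(wavelength=1.55, eff_index=2.4):
    # Transmission of a lossless 10 μm waveguide section
    phase = 2 * jnp.pi * eff_index * 10.0 / wavelength
    return {("in", "out"): jnp.exp(1j * phase)}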


# Rename parameters to standard names
renamed_model = rename_params(
    original_model, {"wavelength": "wl", "eff_index": "neff"}
)
# Now can call: renamed_model(wl=1.55, neff=2.4)
Source code in src/sax/utils.py
def rename_params(model: sax.Model, renamings: dict[str, str]) -> sax.Model:
    """Rename the parameters of a model function.

    Creates a new model with renamed parameters while preserving the original
    functionality and default values.

    Args:
        model: Model function to rename parameters for.
        renamings: Dictionary mapping old parameter names to new names.

    Returns:
        New model function with renamed parameters.

    Raises:
        ValueError: If multiple old names map to the same new name.

    Example:
        ```python
        def original_model(wavelength=1.55, eff_index=2.4):
            return some_s_matrix


        # Rename parameters to standard names
        renamed_model = rename_params(
            original_model, {"wavelength": "wl", "eff_index": "neff"}
        )
        # Now can call: renamed_model(wl=1.55, neff=2.4)
        ```
    """
    reversed_renamings = {v: k for k, v in renamings.items()}
    if len(reversed_renamings) < len(renamings):
        msg = "Multiple old names point to the same new name!"
        raise ValueError(msg)

    if callable(_model := cast(sax.Model, model)):

        @wraps(_model)
        def new_model(**settings: sax.SettingsValue) -> sax.SType:
            old_settings = {
                reversed_renamings.get(k, k): v for k, v in settings.items()
            }
            return _model(**old_settings)

        new_settings = {renamings.get(k, k): v for k, v in get_settings(_model).items()}
        _replace_kwargs(new_model, **new_settings)

        return cast(sax.Model, new_model)

    msg = "rename_params should be used to decorate a Model."
    raise ValueError(msg)

rename_ports

rename_ports(S: SDict, renamings: dict[str, str]) -> SDict
rename_ports(S: SCoo, renamings: dict[str, str]) -> SCoo
rename_ports(S: SDense, renamings: dict[str, str]) -> SDense
rename_ports(S: Model, renamings: dict[str, str]) -> Model
rename_ports(S: SType | Model, renamings: dict[str, str]) -> SType | Model

Rename the ports of an S-matrix or model.

Creates a new S-matrix or model with renamed ports while preserving all S-parameter values and relationships.
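
For SDict the keys themselves are rewritten. For SDense and SCoo only the port-to-index map changes, while the values stay in place. The block below is a minimal plain-Python sketch of both cases. `rename_sdict` and `rename_port_map` are hypothetical names. Like the listing below, it indexes `renamings` directly, so every port must be present.

```python
def rename_sdict(sdict: dict, renamings: dict[str, str]) -> dict:
    # Every port must appear in `renamings`; a missing port raises KeyError
    return {(renamings[p1], renamings[p2]): v for (p1, p2), v in sdict.items()}


def rename_port_map(port_map: dict[str, int], renamings: dict[str, str]) -> dict[str, int]:
    # Dense/COO values are untouched; only the port -> index map is rewritten
    return {renamings[p]: i for p, i in port_map.items()}


mapping = {"input": "in", "output": "out"}
print(rename_sdict({("input", "output"): 0.9}, mapping))  # {("in", "out"): 0.9}
print(rename_port_map({"input": 0, "output": 1}, mapping))  # {"in": 0, "out": 1}
```

Ports that should keep their names have to be mapped to themselves, for example `{"o1": "o1", "o2": "in"}`.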

Parameters:

  • S (SType | Model, required): S-matrix in any format or model function to rename ports for.
  • renamings (dict[str, str], required): Dictionary mapping old port names to new names. Every port of S must appear as a key; a missing port raises a KeyError.

Returns:

  • SType | Model: S-matrix or model with renamed ports.

Raises:

  • ValueError: If the input type is not supported for port renaming.

Example
# Rename ports in S-matrix
s_matrix = {("input", "output"): 0.9}
renamed_s = rename_ports(s_matrix, {"input": "in", "output": "out"})
# Result: {("in", "out"): 0.9}


# Rename ports in model
def original_model(wl=1.55):
    return {("input", "output"): 0.9}


renamed_model = rename_ports(original_model, {"input": "in", "output": "out"})
Source code in src/sax/utils.py
def rename_ports(
    S: sax.SType | sax.Model, renamings: dict[str, str]
) -> sax.SType | sax.Model:
    """Rename the ports of an S-matrix or model.

    Creates a new S-matrix or model with renamed ports while preserving
    all S-parameter values and relationships.

    Args:
        S: S-matrix in any format or model function to rename ports for.
        renamings: Dictionary mapping old port names to new names.

    Returns:
        S-matrix or model with renamed ports.

    Raises:
        ValueError: If the input type is not supported for port renaming.

    Example:
        ```python
        # Rename ports in S-matrix
        s_matrix = {("input", "output"): 0.9}
        renamed_s = rename_ports(s_matrix, {"input": "in", "output": "out"})
        # Result: {("in", "out"): 0.9}


        # Rename ports in model
        def original_model(wl=1.55):
            return {("input", "output"): 0.9}


        renamed_model = rename_ports(original_model, {"input": "in", "output": "out"})
        ```
    """
    if callable(model := S):

        @wraps(model)
        def new_model(**settings: sax.SettingsValue) -> sax.SType:
            return rename_ports(model(**settings), renamings)

        return cast(sax.Model, new_model)

    if isinstance(sdict := S, dict):
        return {(renamings[p1], renamings[p2]): v for (p1, p2), v in sdict.items()}

    if len(scoo := cast(sax.SCoo, S)) == 4:
        Si, Sj, Sx, ports_map = scoo
        ports_map = {renamings[p]: i for p, i in ports_map.items()}
        return Si, Sj, Sx, ports_map

    if len(sdense := cast(sax.SDense, S)) == 2:
        Sx, ports_map = sdense
        ports_map = {renamings[p]: i for p, i in ports_map.items()}
        return Sx, ports_map

    msg = f"Cannot rename ports for type {type(S)}"
    raise ValueError(msg)

replace_kwargs

replace_kwargs(func: Callable, **kwargs: SettingsValue) -> None

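Replace the signature of a function with keyword-only parameters whose defaults are the given values. Existing parameters are dropped from the signature. Only the signature seen by introspection changes, not how the function body receives its arguments.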
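
The effect is easiest to see through `inspect.signature`. The block below is a minimal sketch using the same standard-library calls as the listing. `set_keyword_signature` and `model` are hypothetical names used for illustration.

```python
import inspect


def set_keyword_signature(func, **defaults) -> None:
    """Same idea as replace_kwargs: expose only keyword-only params with defaults."""
    params = [
        inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=value)
        for name, value in defaults.items()
    ]
    # inspect.signature() honours a __signature__ attribute if one is set
    func.__signature__ = inspect.signature(func).replace(parameters=params)


def model(**settings):
    return settings


set_keyword_signature(model, wl=1.55, length=10.0)
print(inspect.signature(model))  # (*, wl=1.55, length=10.0)
```

The advertised defaults are not injected into calls. `model()` still receives no arguments, which is why SAX pairs this with wrappers that forward settings explicitly.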

Source code in src/sax/utils.py
def replace_kwargs(func: Callable, **kwargs: sax.SettingsValue) -> None:
    """Change the kwargs signature of a function."""
    sig = inspect.signature(func)
    settings = [
        inspect.Parameter(k, inspect.Parameter.KEYWORD_ONLY, default=v)
        for k, v in kwargs.items()
    ]
    func.__signature__ = sig.replace(parameters=settings)  # type: ignore # noqa: PGH003

scoo

scoo(S: Model) -> SCooModel
scoo(S: SType) -> SCoo
scoo(S: Model | SType) -> SCooModel | SCoo

Convert an S-matrix to SCoo (coordinate) format.

SCoo format represents S-parameters in coordinate (sparse) format with separate arrays for row indices (Si), column indices (Sj), values (Sx), and port mapping. This format is memory-efficient for sparse matrices.
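
The coordinate layout can be sketched in plain Python. Each stored entry contributes one row index, one column index and one value. `sdict_to_coo` is a hypothetical name, and assigning port indices in sorted order is an assumption of this sketch; SAX's own ordering may differ.

```python
def sdict_to_coo(sdict: dict) -> tuple[list[int], list[int], list[complex], dict[str, int]]:
    # Assumption: ports are indexed in sorted order
    ports = sorted({p for pair in sdict for p in pair})
    port_map = {p: i for i, p in enumerate(ports)}
    Si, Sj, Sx = [], [], []
    for (p1, p2), value in sdict.items():
        Si.append(port_map[p1])  # row index (first port of the pair)
        Sj.append(port_map[p2])  # column index (second port of the pair)
        Sx.append(value)         # only entries present in the dict are stored
    return Si, Sj, Sx, port_map


Si, Sj, Sx, port_map = sdict_to_coo({("in", "out"): 0.9 + 0.1j, ("out", "in"): 0.9 + 0.1j})
print(Si, Sj, port_map)  # [0, 1] [1, 0] {'in': 0, 'out': 1}
```

For an N-port device with few nonzero couplings, the three lists hold one element per stored entry instead of N * N.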

Parameters:

  • S (Model | SType, required): S-matrix in SDict or SDense format, or a model that returns such matrices.

Returns:

  • SCooModel | SCoo: S-matrix in SCoo format (Si, Sj, Sx, port_map), or a model that returns SCoo format.

Raises:

  • ValueError: If the input cannot be converted to SCoo format.

Example
# Convert from dict format
sdict_matrix = {("in", "out"): 0.9 + 0.1j, ("out", "in"): 0.9 + 0.1j}
Si, Sj, Sx, port_map = scoo(sdict_matrix)


# Convert a model
def my_model(wl=1.55):
    return sdict_matrix


scoo_model = scoo(my_model)
Source code in src/sax/s.py
def scoo(S: sax.Model | sax.SType) -> sax.SCooModel | sax.SCoo:
    """Convert an S-matrix to SCoo (coordinate) format.

    SCoo format represents S-parameters in coordinate (sparse) format with
    separate arrays for row indices (Si), column indices (Sj), values (Sx),
    and port mapping. This format is memory-efficient for sparse matrices.

    Args:
        S: S-matrix in SDict, SDense format, or a model that returns such matrices.

    Returns:
        S-matrix in SCoo format (Si, Sj, Sx, port_map), or a model that returns
        SCoo format.

    Raises:
        ValueError: If the input cannot be converted to SCoo format.

    Example:
        ```python
        # Convert from dict format
        sdict_matrix = {("in", "out"): 0.9 + 0.1j, ("out", "in"): 0.9 + 0.1j}
        Si, Sj, Sx, port_map = scoo(sdict_matrix)


        # Convert a model
        def my_model(wl=1.55):
            return some_sdict_matrix


        scoo_model = scoo(my_model)
        ```
    """
    if callable(_model := S):

        @wraps(_model)
        def model(**kwargs: sax.SettingsValue) -> sax.SCoo:
            return scoo(_model(**kwargs))

        return model

    if isinstance(_sdict := S, dict):
        return _sdict_to_scoo(_sdict)

    if len(_scoo := cast(tuple, S)) == 4:
        return _scoo

    if len(_sdense := cast(tuple, S)) == 2:
        return _sdense_to_scoo(*_sdense)

    msg = f"Could not convert S-matrix to scoo. Got: {S!r}."
    raise ValueError(msg)

sdense

sdense(S: Model) -> SDenseModel
sdense(S: SType) -> SDense
sdense(S: SType | Model) -> SDenseModel | SDense

Convert an S-matrix to SDense (dense matrix) format.

SDense format represents S-parameters as a dense complex matrix with an associated port mapping. This format is efficient for dense matrices and enables fast linear algebra operations.
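
A dense representation fills every (row, column) slot, with zeros where the SDict has no entry. The block below is a minimal nested-list sketch, not SAX code. `sdict_to_dense` is a hypothetical name, and the sorted port ordering is an assumption.

```python
def sdict_to_dense(sdict: dict) -> tuple[list[list[complex]], dict[str, int]]:
    ports = sorted({p for pair in sdict for p in pair})  # assumed ordering
    port_map = {p: i for i, p in enumerate(ports)}
    n = len(ports)
    dense = [[0j] * n for _ in range(n)]  # entries missing from sdict stay zero
    for (p1, p2), value in sdict.items():
        dense[port_map[p1]][port_map[p2]] = value
    return dense, port_map


dense, port_map = sdict_to_dense({("in", "out"): 0.9 + 0.1j, ("out", "in"): 0.9 + 0.1j})
print(dense)  # [[0j, (0.9+0.1j)], [(0.9+0.1j), 0j]]
```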

Parameters:

  • S (SType | Model, required): S-matrix in SDict or SCoo format, or a model that returns such matrices.

Returns:

  • SDenseModel | SDense: S-matrix in SDense format (matrix, port_map), or a model that returns SDense format.

Raises:

  • ValueError: If the input cannot be converted to SDense format.

Example
# Convert from dict format
sdict_matrix = {("in", "out"): 0.9 + 0.1j, ("out", "in"): 0.9 + 0.1j}
matrix, port_map = sdense(sdict_matrix)


# Convert a model
def my_model(wl=1.55):
    return sdict_matrix


sdense_model = sdense(my_model)
Source code in src/sax/s.py
def sdense(S: sax.SType | sax.Model) -> sax.SDenseModel | sax.SDense:
    """Convert an S-matrix to SDense (dense matrix) format.

    SDense format represents S-parameters as a dense complex matrix with an
    associated port mapping. This format is efficient for dense matrices and
    enables fast linear algebra operations.

    Args:
        S: S-matrix in SDict, SCoo format, or a model that returns such matrices.

    Returns:
        S-matrix in SDense format (matrix, port_map), or a model that returns
        SDense format.

    Raises:
        ValueError: If the input cannot be converted to SDense format.

    Example:
        ```python
        # Convert from dict format
        sdict_matrix = {("in", "out"): 0.9 + 0.1j, ("out", "in"): 0.9 + 0.1j}
        matrix, port_map = sdense(sdict_matrix)


        # Convert a model
        def my_model(wl=1.55):
            return some_sdict_matrix


        sdense_model = sdense(my_model)
        ```
    """
    if callable(_model := S):

        @wraps(_model)
        def model(**kwargs: sax.SettingsValue) -> sax.SDense:
            return sdense(_model(**kwargs))

        return model

    if isinstance(_sdict := S, dict):
        return _sdict_to_sdense(_sdict)

    if len(_scoo := cast(tuple, S)) == 4:
        return _scoo_to_sdense(*_scoo)

    if len(_sdense := cast(tuple, S)) == 2:
        return _sdense

    msg = f"Could not convert S-matrix to sdense. Got: {S!r}."
    raise ValueError(msg)

sdict

sdict(S: Model) -> SDictModel
sdict(S: SType) -> SDict
sdict(S: Model | SType) -> SDictModel | SDict

Convert an S-matrix to SDict (dictionary) format.

SDict format represents S-parameters as a dictionary with port pair tuples as keys and complex values as entries. This is the most flexible format for sparse matrices and custom port naming.
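
Going from a dense matrix plus port map to a dictionary means pairing every row port with every column port. The block below is a minimal plain-Python sketch. `dense_to_sdict` is a hypothetical name, and it assumes zero entries are kept, which may not match SAX exactly.

```python
def dense_to_sdict(dense: list[list[complex]], port_map: dict[str, int]) -> dict:
    # One key per (row port, column port) pair, zeros included
    return {
        (p1, p2): dense[i][j]
        for p1, i in port_map.items()
        for p2, j in port_map.items()
    }


s = dense_to_sdict([[0j, 0.5], [0.5, 0j]], {"in": 0, "out": 1})
nonzero = {k: v for k, v in s.items() if v != 0}  # optional sparse view
print(nonzero)  # {('in', 'out'): 0.5, ('out', 'in'): 0.5}
```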

Parameters:

  • S (Model | SType, required): S-matrix in SCoo or SDense format, or a model that returns such matrices.

Returns:

  • SDictModel | SDict: S-matrix in SDict format, or a model that returns SDict format.

Raises:

  • ValueError: If the input cannot be converted to SDict format.

Example
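import jax.numpy as jnp

# Convert from SCoo format (Si, Sj, Sx, port_map)
Si = jnp.array([0, 1])
Sj = jnp.array([1, 0])
Sx = jnp.array([0.9 + 0.1j, 0.9 + 0.1j])
port_map = {"in": 0, "out": 1}
sdict_matrix = sdict((Si, Sj, Sx, port_map))


# Convert a model
def my_model(wl=1.55):
    return sdense(sdict_matrix)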


sdict_model = sdict(my_model)
Source code in src/sax/s.py
def sdict(S: sax.Model | sax.SType) -> sax.SDictModel | sax.SDict:
    """Convert an S-matrix to SDict (dictionary) format.

    SDict format represents S-parameters as a dictionary with port pair tuples
    as keys and complex values as entries. This is the most flexible format for
    sparse matrices and custom port naming.

    Args:
        S: S-matrix in SCoo, SDense format, or a model that returns such matrices.

    Returns:
        S-matrix in SDict format, or a model that returns SDict format.

    Raises:
        ValueError: If the input cannot be converted to SDict format.

    Example:
        ```python
        # Convert from other formats
        scoo_matrix = (Si, Sj, Sx, port_map)
        sdict_matrix = sdict(scoo_matrix)


        # Convert a model
        def my_model(wl=1.55):
            return some_sdense_matrix


        sdict_model = sdict(my_model)
        ```
    """
    if callable(_model := S):

        @wraps(_model)
        def model(**kwargs: sax.SettingsValue) -> sax.SDict:
            return sdict(_model(**kwargs))

        return model

    if isinstance(_sdict := S, dict):
        return _sdict

    if len(_scoo := cast(tuple, S)) == 4:
        return _scoo_to_sdict(*_scoo)

    if len(_sdense := cast(tuple, S)) == 2:
        return _sdense_to_sdict(*_sdense)

    msg = f"Could not convert S-matrix to sdict. Got: {S!r}."
    raise ValueError(msg)

set_port_naming_strategy

set_port_naming_strategy(strategy: PortNamingStrategy) -> PortNamingStrategy

Set the port naming strategy for the default SAX models.

This function configures the global port naming convention used by SAX model functions. The strategy affects how ports are named in generated S-matrices and impacts netlist compatibility with different simulation tools.
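
The setter follows a validate-then-store pattern on a module-level global. The block below is a minimal sketch of that pattern, not SAX's module. `set_strategy`, `get_strategy` and the "inout" starting value are hypothetical. One deliberate difference from the listing below: this sketch stores the lower-cased value, whereas the listing validates the lower-cased string but stores `strategy` as passed.

```python
VALID_STRATEGIES = ("optical", "inout")
_current_strategy = "inout"  # hypothetical starting value


def set_strategy(strategy: str) -> str:
    """Validate (case-insensitively) and store a port naming strategy."""
    global _current_strategy
    normalized = str(strategy).lower()
    if normalized not in VALID_STRATEGIES:
        raise ValueError(f"Invalid port naming strategy: {strategy}")
    # Store the lower-cased value so later equality checks are exact
    _current_strategy = normalized
    return normalized


def get_strategy() -> str:
    return _current_strategy


set_strategy("Optical")
print(get_strategy())  # optical
```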

Parameters:

  • strategy (PortNamingStrategy, required): Port naming strategy to use. Options are:
      - "optical": Uses optical convention (o1, o2, o3, ...)
      - "inout": Uses input/output convention (in0, in1, out0, out1, ...)

Returns:

  • PortNamingStrategy: The strategy that was set (for confirmation).

Raises:

  • ValueError: If the strategy is not one of the valid options.

Examples:

Switch to optical naming:

import sax.models

sax.models.set_port_naming_strategy("optical")
# Now models will use o1, o2, o3, ... port names

Switch to input/output naming:

sax.models.set_port_naming_strategy("inout")
# Now models will use in0, in1, out0, out1, ... port names

Check current strategy:

current = sax.models.get_port_naming_strategy()
print(f"Current strategy: {current}")
Note

This is a global setting that affects all subsequently created model instances. Existing model instances retain their original port naming. The "inout" convention is more intuitive for circuit simulation, while "optical" follows traditional optical component naming.

Source code in src/sax/ports.py
def set_port_naming_strategy(strategy: PortNamingStrategy) -> PortNamingStrategy:
    """Set the port naming strategy for the default SAX models.

    This function configures the global port naming convention used by SAX model
    functions. The strategy affects how ports are named in generated S-matrices
    and impacts netlist compatibility with different simulation tools.

    Args:
        strategy: Port naming strategy to use. Options are:
            - "optical": Uses optical convention (o1, o2, o3, ...)
            - "inout": Uses input/output convention (in0, in1, out0, out1, ...)

    Returns:
        The strategy that was set (for confirmation).

    Raises:
        ValueError: If the strategy is not one of the valid options.

    Examples:
        Switch to optical naming:

        ```python
        import sax.models

        sax.models.set_port_naming_strategy("optical")
        # Now models will use o1, o2, o3, ... port names
        ```

        Switch to input/output naming:

        ```python
        sax.models.set_port_naming_strategy("inout")
        # Now models will use in0, in1, out0, out1, ... port names
        ```

        Check current strategy:

        ```python
        current = sax.models.get_port_naming_strategy()
        print(f"Current strategy: {current}")
        ```

    Note:
        This is a global setting that affects all subsequently created model
        instances. Existing model instances retain their original port naming.
        The "inout" convention is more intuitive for circuit simulation, while
        "optical" follows traditional optical component naming.
    """
    global PORT_NAMING_STRATEGY  # noqa: PLW0603
    _strategy = str(strategy).lower()
    if _strategy not in ["optical", "inout"]:
        msg = f"Invalid port naming strategy: {strategy}"
        raise ValueError(msg)
    PORT_NAMING_STRATEGY = strategy
    return strategy

singlemode

singlemode(S: SDictModel, mode: str = DEFAULT_MODE) -> SDictModelSM
singlemode(S: SCooModel, mode: str = DEFAULT_MODE) -> SCooModelSM
singlemode(S: SDenseModel, mode: str = DEFAULT_MODE) -> SDenseModelSM
singlemode(S: SDict, mode: str = DEFAULT_MODE) -> SDictSM
singlemode(S: SCoo, mode: str = DEFAULT_MODE) -> SCooSM
singlemode(S: SDense, mode: str = DEFAULT_MODE) -> SDenseSM
singlemode(S: SType | Model, mode: Mode = DEFAULT_MODE) -> STypeSM | ModelSM

Convert a multimode S-matrix or model to single-mode.

Extracts a single optical mode from a multimode S-matrix, effectively filtering out all other modes. The resulting single-mode S-matrix will have standard port naming (without @mode suffix).
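
For the SDict case, extracting a mode means keeping only entries whose two ports both carry the requested "@mode" suffix, then stripping that suffix. The block below is a minimal plain-Python sketch. `extract_mode` is a hypothetical name, and it assumes "@" is the mode separator, as in the example below.

```python
def extract_mode(sdict_mm: dict, mode: str = "TE") -> dict:
    suffix = f"@{mode}"
    result = {}
    for (p1, p2), value in sdict_mm.items():
        # Keep only same-mode entries; cross-mode terms such as
        # ("in@TE", "out@TM") are dropped.
        if p1.endswith(suffix) and p2.endswith(suffix):
            result[(p1[: -len(suffix)], p2[: -len(suffix)])] = value
    return result


s_mm = {
    ("in@TE", "out@TE"): 0.9,
    ("in@TM", "out@TM"): 0.8,
    ("in@TE", "out@TM"): 0.1,
}
print(extract_mode(s_mm, "TE"))  # {('in', 'out'): 0.9}
```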

Parameters:

  • S (SType | Model, required): Multimode S-matrix in any format or a model that returns such matrices.
  • mode (Mode, default: DEFAULT_MODE): The optical mode to extract (e.g., "TE", "TM"). Defaults to "TE".

Returns:

  • STypeSM | ModelSM: Single-mode S-matrix with standard port naming, or a model that returns such matrices. Input that is already single-mode is returned unchanged.

Raises:

  • ValueError: If the input cannot be converted to single-mode.

Example
# Extract TE mode from multimode S-matrix
s_mm = {
    ("in@TE", "out@TE"): 0.9,
    ("in@TM", "out@TM"): 0.8,
    ("in@TE", "out@TM"): 0.1,
}
s_te = singlemode(s_mm, mode="TE")
# Result: {("in", "out"): 0.9}


# Convert a multimode model to single-mode
def multimode_model(wl=1.55):
    return s_mm


te_model = singlemode(multimode_model, mode="TE")
Source code in src/sax/multimode.py
def singlemode(
    S: sax.SType | sax.Model, mode: sax.Mode = DEFAULT_MODE
) -> sax.STypeSM | sax.ModelSM:
    """Convert a multimode S-matrix or model to single-mode.

    Extracts a single optical mode from a multimode S-matrix, effectively
    filtering out all other modes. The resulting single-mode S-matrix will
    have standard port naming (without @mode suffix).

    Args:
        S: Multimode S-matrix in any format or a model that returns such matrices.
        mode: The optical mode to extract (e.g., "TE", "TM"). Defaults to "TE".

    Returns:
        Single-mode S-matrix with standard port naming, or a model that returns
        such matrices.

    Raises:
        ValueError: If the input cannot be converted to single-mode.

    Example:
        ```python
        # Extract TE mode from multimode S-matrix
        s_mm = {
            ("in@TE", "out@TE"): 0.9,
            ("in@TM", "out@TM"): 0.8,
            ("in@TE", "out@TM"): 0.1,
        }
        s_te = singlemode(s_mm, mode="TE")
        # Result: {("in", "out"): 0.9}


        # Convert a multimode model to single-mode
        def multimode_model(wl=1.55):
            return multimode_s_matrix


        te_model = singlemode(multimode_model, mode="TE")
        ```
    """
    if (model := sax.try_into[sax.Model](S)) is not None:

        @wraps(model)
        def new_model(**params: sax.SettingsValue) -> sax.STypeSM:
            return singlemode(model(**params), mode=mode)

        return cast(sax.ModelSM, new_model)

    if (sdict_mm := sax.try_into[sax.SDictMM](S)) is not None:
        return _singlemode_sdict(sdict_mm, mode=mode)

    if (scoo_mm := sax.try_into[sax.SCooMM](S)) is not None:
        return _singlemode_scoo(scoo_mm, mode=mode)

    if (sdense_mm := sax.try_into[sax.SDenseMM](S)) is not None:
        return _singlemode_sdense(sdense_mm, mode=mode)

    if (s_sm := sax.try_into[sax.STypeSM](S)) is not None:
        return s_sm

    msg = f"Cannot convert {S!r} to singlemode. Unknown SType."
    raise ValueError(msg)

unflatten_dict

unflatten_dict(dic: dict[str, Any], sep: str = ',') -> dict[str, Any]

Unflatten a dictionary by splitting keys and creating nested structure.

Converts a flattened dictionary back to nested form by splitting keys on the separator and creating the nested hierarchy.
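
Unflattening is the inverse of joining nested keys with the separator, so a round trip makes a convenient check. The block below is a minimal sketch. `flatten` is a hypothetical helper (not shown in this section), and `unflatten` is a compact `setdefault` rewrite of the loop in the listing below. Keys that themselves contain the separator are split too, so choose a separator that does not occur in component names.

```python
from typing import Any


def flatten(dic: dict[str, Any], sep: str = ",", prefix: str = "") -> dict[str, Any]:
    flat = {}
    for key, value in dic.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, sep, full_key))  # descend into sub-dicts
        else:
            flat[full_key] = value
    return flat


def unflatten(dic: dict[str, Any], sep: str = ",") -> dict[str, Any]:
    nested: dict[str, Any] = {}
    for key, value in dic.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels on demand
        node[leaf] = value
    return nested


flat = {"a,b": 1, "a,c,d": 2, "e": 3}
nested = unflatten(flat)
print(nested)  # {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
```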

Parameters:

  • dic (dict[str, Any], required): Flattened dictionary to unflatten.
  • sep (str, default: ','): Separator used in the flattened keys.

Returns:

  • dict[str, Any]: Nested dictionary with original structure restored.

Example
flat = {"a,b": 1, "a,c,d": 2, "e": 3}
nested = unflatten_dict(flat)
# Result: {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
Source code in src/sax/utils.py
def unflatten_dict(dic: dict[str, Any], sep: str = ",") -> dict[str, Any]:
    """Unflatten a dictionary by splitting keys and creating nested structure.

    Converts a flattened dictionary back to nested form by splitting keys
    on the separator and creating the nested hierarchy.

    Args:
        dic: Flattened dictionary to unflatten.
        sep: Separator used in the flattened keys. Defaults to ",".

    Returns:
        Nested dictionary with original structure restored.

    Example:
        ```python
        flat = {"a,b": 1, "a,c,d": 2, "e": 3}
        nested = unflatten_dict(flat)
        # Result: {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
        ```
    """
    # from: https://gist.github.com/fmder/494aaa2dd6f8c428cede
    items = {}

    for k, v in dic.items():
        keys = k.split(sep)
        sub_items = items
        for ki in keys[:-1]:
            if ki in sub_items:
                sub_items = sub_items[ki]
            else:
                sub_items[ki] = {}
                sub_items = sub_items[ki]

        sub_items[keys[-1]] = v

    return items

update_settings

update_settings(
    settings: Settings, *compnames: str, **kwargs: SettingsValue
) -> Settings

Update a nested settings dictionary with new parameter values.

Updates settings for specific components or globally. Supports nested parameter dictionaries and selective component updates.
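
Two details of the listing below are easy to miss. A global update only replaces keys that already exist. A component path descends one name per nesting level. The block below is a simplified sketch that omits SAX's `_try_complex_float` conversion. `update_nested` is a hypothetical name.

```python
def update_nested(settings: dict, *compnames: str, **kwargs) -> dict:
    updated = {}
    for key, value in settings.items():
        if isinstance(value, dict):
            if not compnames:
                updated[key] = update_nested(value, **kwargs)  # recurse everywhere
            elif key == compnames[0]:
                # Follow the component path one level deeper
                updated[key] = update_nested(value, *compnames[1:], **kwargs)
            else:
                updated[key] = value  # untouched component (shared, not copied)
        elif not compnames and key in kwargs:
            updated[key] = kwargs[key]  # only keys that already exist are replaced
        else:
            updated[key] = value
    return updated


settings = {
    "wg1": {"length": 10.0, "neff": 2.4},
    "wg2": {"length": 20.0, "neff": 2.4},
}
print(update_nested(settings, wl=1.55) == settings)  # True: no "wl" key to update
print(update_nested(settings, "wg1", length=15.0))
```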

Parameters:

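  • settings (Settings, required): Original settings dictionary to update.
  • *compnames (str, default: ()): Path of component names to descend into (the first name selects a top-level component, the next a sub-component, and so on). If empty, updates all components.
  • **kwargs (SettingsValue, default: {}): Parameter values to update. Only parameters that already exist in settings are replaced; new keys are not added.

Returns:

  • Settings: Updated settings dictionary (does not modify original).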

Note
  • This operation never updates the given settings dictionary in place.
  • Any non-float keyword arguments will be silently ignored.
  • Even though it's possible to update parameter dictionaries in place, this function is convenient to apply certain parameters (e.g. wavelength 'wl' or temperature 'T') globally.
Example
settings = {
    "wg1": {"length": 10.0, "neff": 2.4},
    "wg2": {"length": 20.0, "neff": 2.4},
}
# Update neff in all components (only existing keys are updated)
updated = update_settings(settings, neff=2.5)

# Update specific component
updated = update_settings(settings, "wg1", length=15.0)
Source code in src/sax/utils.py
def update_settings(
    settings: sax.Settings, *compnames: str, **kwargs: sax.SettingsValue
) -> sax.Settings:
    """Update a nested settings dictionary with new parameter values.

    Updates settings for specific components or globally. Supports nested
    parameter dictionaries and selective component updates.

    Args:
        settings: Original settings dictionary to update.
        *compnames: Component names to update. If empty, updates all components.
        **kwargs: Parameter values to update.

    Returns:
        Updated settings dictionary (does not modify original).

    Note:
        - This operation never updates the given settings dictionary in place.
        - Any non-float keyword arguments will be silently ignored.
        - Even though it's possible to update parameter dictionaries in place,
          this function is convenient to apply certain parameters (e.g. wavelength
          'wl' or temperature 'T') globally.

    Example:
        ```python
        settings = {
            "wg1": {"length": 10.0, "neff": 2.4},
            "wg2": {"length": 20.0, "neff": 2.4},
        }
        # Update all components
        updated = update_settings(settings, wl=1.55)

        # Update specific component
        updated = update_settings(settings, "wg1", length=15.0)
        ```
    """
    _settings = {}
    if not compnames:
        for k, v in settings.items():
            if isinstance(v, dict):
                _settings[k] = update_settings(v, **kwargs)
            elif k in kwargs:
                _settings[k] = _try_complex_float(kwargs[k])
            else:
                _settings[k] = _try_complex_float(v)
    else:
        for k, v in settings.items():
            if isinstance(v, dict):
                if k == compnames[0]:
                    _settings[k] = update_settings(v, *compnames[1:], **kwargs)
                else:
                    _settings[k] = v
            else:
                _settings[k] = _try_complex_float(v)
    return _settings

wl_c

wl_c(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_C_MIN,
    wl_max: float = WL_C_MAX,
) -> FloatArray1D

Generate wavelength array in the C-band (Conventional band: 1530-1565 nm).

The C-band is the most commonly used wavelength range in optical communications due to the low attenuation characteristics of standard single-mode fiber.
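
The two sampling modes (fixed `num` versus fixed `step`) can be sketched in plain Python, with wavelengths in μm (1530-1565 nm is 1.530-1.565 μm). This is not SAX's `_wl`. `wavelength_grid` is a hypothetical name. Treating `num` as linspace-style with both endpoints included is an assumption, and the default step of 1e-4 μm is an arbitrary choice, not `DEFAULT_WL_STEP`.

```python
import math
from typing import Optional


def wavelength_grid(
    wl_min: float, wl_max: float, *, step: float = 1e-4, num: Optional[int] = None
) -> list[float]:
    if num is not None:
        if num == 1:
            return [wl_min]
        # num points with both endpoints included (linspace-style)
        return [wl_min + i * (wl_max - wl_min) / (num - 1) for i in range(num)]
    # Fixed step; the small tolerance absorbs floating-point error in the division
    count = math.floor((wl_max - wl_min) / step + 1e-9) + 1
    return [wl_min + i * step for i in range(count)]


grid = wavelength_grid(1.530, 1.565, num=100)
print(len(grid), grid[0], grid[-1])
```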

Parameters:

  • step (float, default: DEFAULT_WL_STEP): Wavelength step size in μm. Used when num is None.
  • num (int | None, default: None): Number of wavelength points. If provided, step is ignored.
  • wl_min (float, default: WL_C_MIN): Minimum wavelength in μm. Defaults to the C-band minimum.
  • wl_max (float, default: WL_C_MAX): Maximum wavelength in μm. Defaults to the C-band maximum.

Returns:

  • FloatArray1D: 1D array of wavelengths in μm.

Example
# Generate C-band wavelengths with default step
wl = wl_c()
# Generate 100 equally spaced points in C-band
wl = wl_c(num=100)
# High-resolution C-band scan (0.00001 μm = 0.01 nm step)
wl = wl_c(step=0.00001)
Source code in src/sax/constants.py
def wl_c(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_C_MIN,
    wl_max: float = WL_C_MAX,
) -> FloatArray1D:
    """Generate wavelength array in the C-band (Conventional band: 1530-1565 nm).

    The C-band is the most commonly used wavelength range in optical communications
    due to the low attenuation characteristics of standard single-mode fiber.

    Args:
        step: Wavelength step size in μm. Used when num is None.
        num: Number of wavelength points. If provided, step is ignored.
        wl_min: Minimum wavelength in μm. Defaults to C-band minimum.
        wl_max: Maximum wavelength in μm. Defaults to C-band maximum.

    Returns:
        1D array of wavelengths in μm.

    Example:
        ```python
        # Generate C-band wavelengths with default step
        wl = wl_c()
        # Generate 100 equally spaced points in C-band
        wl = wl_c(num=100)
        # High resolution C-band scan
        wl = wl_c(step=0.00001)
        ```
    """
    return _wl(step=step, num=num, wl_min=wl_min, wl_max=wl_max)
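The precedence rule in the parameters above (`num` wins, otherwise `step` sets the spacing) can be sketched in plain Python. This is a hypothetical stand-in for the private `_wl` helper, whose source is not shown here: the name `wavelength_grid`, the endpoint handling, and the `0.001` default step are assumptions, not SAX behavior. The C-band limits 1.530 and 1.565 μm are converted from the 1530-1565 nm range stated above.

```python
from __future__ import annotations

import math


def wavelength_grid(
    wl_min: float,
    wl_max: float,
    *,
    step: float = 0.001,  # assumed default; SAX's DEFAULT_WL_STEP may differ
    num: int | None = None,
) -> list[float]:
    """Hypothetical wavelength grid in μm: num points if given, else fixed step spacing."""
    if num is not None:
        # num takes precedence: evenly spaced points with both endpoints included
        if num == 1:
            return [wl_min]
        spacing = (wl_max - wl_min) / (num - 1)
        return [wl_min + i * spacing for i in range(num)]
    # step mode: whole steps that fit in the range; the small tolerance absorbs
    # float error so wl_max is kept when the range is a multiple of step
    n_steps = math.floor((wl_max - wl_min) / step + 1e-9)
    return [wl_min + i * step for i in range(n_steps + 1)]


# C-band limits in μm, converted from the 1530-1565 nm range
c_band_100 = wavelength_grid(1.530, 1.565, num=100)
c_band_1nm = wavelength_grid(1.530, 1.565, step=0.001)
```

In this sketch, `num` always samples both endpoints, while `step` reaches `wl_max` only when the range is a whole number of steps. Whether `_wl` uses the same endpoint convention is not shown in this excerpt.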

wl_e

wl_e(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_E_MIN,
    wl_max: float = WL_E_MAX,
) -> FloatArray1D

Generate a wavelength array, in μm, spanning the E-band (Extended band: 1360-1460 nm).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step` | `float` | Wavelength step size in μm. Used when `num` is `None`. | `DEFAULT_WL_STEP` |
| `num` | `int \| None` | Number of wavelength points. If provided, `step` is ignored. | `None` |
| `wl_min` | `float` | Minimum wavelength in μm. Defaults to E-band minimum. | `WL_E_MIN` |
| `wl_max` | `float` | Maximum wavelength in μm. Defaults to E-band maximum. | `WL_E_MAX` |

Returns:

| Type | Description |
| --- | --- |
| `FloatArray1D` | 1D array of wavelengths in μm. |

Example
```python
from sax.constants import wl_e

# Generate E-band wavelengths with the default step
wl = wl_e()
# Generate 50 equally spaced points in the E-band
wl = wl_e(num=50)
```
Source code in src/sax/constants.py
def wl_e(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_E_MIN,
    wl_max: float = WL_E_MAX,
) -> FloatArray1D:
    """Generate wavelength array in the E-band (Extended band: 1360-1460 nm).

    Args:
        step: Wavelength step size in μm. Used when num is None.
        num: Number of wavelength points. If provided, step is ignored.
        wl_min: Minimum wavelength in μm. Defaults to E-band minimum.
        wl_max: Maximum wavelength in μm. Defaults to E-band maximum.

    Returns:
        1D array of wavelengths in μm.

    Example:
        ```python
        # Generate E-band wavelengths with default step
        wl = wl_e()
        # Generate 50 equally spaced points in E-band
        wl = wl_e(num=50)
        ```
    """
    return _wl(step=step, num=num, wl_min=wl_min, wl_max=wl_max)

wl_l

wl_l(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_L_MIN,
    wl_max: float = WL_L_MAX,
) -> FloatArray1D

Generate a wavelength array, in μm, spanning the L-band (Long band: 1565-1625 nm).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step` | `float` | Wavelength step size in μm. Used when `num` is `None`. | `DEFAULT_WL_STEP` |
| `num` | `int \| None` | Number of wavelength points. If provided, `step` is ignored. | `None` |
| `wl_min` | `float` | Minimum wavelength in μm. Defaults to L-band minimum. | `WL_L_MIN` |
| `wl_max` | `float` | Maximum wavelength in μm. Defaults to L-band maximum. | `WL_L_MAX` |

Returns:

| Type | Description |
| --- | --- |
| `FloatArray1D` | 1D array of wavelengths in μm. |

Example
```python
from sax.constants import wl_l

# Generate L-band wavelengths with the default step
wl = wl_l()
# Generate 100 equally spaced points in the L-band
wl = wl_l(num=100)
```
Source code in src/sax/constants.py
def wl_l(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_L_MIN,
    wl_max: float = WL_L_MAX,
) -> FloatArray1D:
    """Generate wavelength array in the L-band (Long band: 1565-1625 nm).

    Args:
        step: Wavelength step size in μm. Used when num is None.
        num: Number of wavelength points. If provided, step is ignored.
        wl_min: Minimum wavelength in μm. Defaults to L-band minimum.
        wl_max: Maximum wavelength in μm. Defaults to L-band maximum.

    Returns:
        1D array of wavelengths in μm.

    Example:
        ```python
        # Generate L-band wavelengths with default step
        wl = wl_l()
        # Generate 100 equally spaced points in L-band
        wl = wl_l(num=100)
        ```
    """
    return _wl(step=step, num=num, wl_min=wl_min, wl_max=wl_max)
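The L-band starts at 1565 nm, exactly where the C-band ends. If a C+L sweep is assembled from two separately generated grids and both include their endpoints (an assumption, since `_wl` is not shown here), the 1.565 μm sample appears twice, which can trip up code that expects strictly increasing wavelengths. A minimal stdlib sketch of joining two such grids follows; `merge_adjacent_bands` is a hypothetical helper, not part of SAX, and the coarse 5 nm grids are illustrative values.

```python
from __future__ import annotations


def merge_adjacent_bands(
    lower: list[float], upper: list[float], tol: float = 1e-9
) -> list[float]:
    """Join two ascending μm grids, dropping the upper grid's first sample
    when it duplicates the lower grid's last one (within tol)."""
    if lower and upper and abs(upper[0] - lower[-1]) <= tol:
        return lower + upper[1:]
    return lower + upper


# Coarse 5 nm grids, both endpoint-inclusive:
# C-band 1.530-1.565 μm and L-band 1.565-1.625 μm
c_band = [1.530 + 0.005 * i for i in range(8)]
l_band = [1.565 + 0.005 * i for i in range(13)]
c_plus_l = merge_adjacent_bands(c_band, l_band)
```

When one contiguous grid is enough, a single call with an overridden limit, such as `wl_c(wl_max=1.625)`, avoids the seam entirely, since `wl_min` and `wl_max` are ordinary keyword parameters.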

wl_o

wl_o(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_O_MIN,
    wl_max: float = WL_O_MAX,
) -> FloatArray1D

Generate a wavelength array, in μm, spanning the O-band (Original band: 1260-1360 nm).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step` | `float` | Wavelength step size in μm. Used when `num` is `None`. | `DEFAULT_WL_STEP` |
| `num` | `int \| None` | Number of wavelength points. If provided, `step` is ignored. | `None` |
| `wl_min` | `float` | Minimum wavelength in μm. Defaults to O-band minimum. | `WL_O_MIN` |
| `wl_max` | `float` | Maximum wavelength in μm. Defaults to O-band maximum. | `WL_O_MAX` |

Returns:

| Type | Description |
| --- | --- |
| `FloatArray1D` | 1D array of wavelengths in μm. |

Example
```python
from sax.constants import wl_o

# Generate O-band wavelengths with the default step
wl = wl_o()
# Generate 100 equally spaced points in the O-band
wl = wl_o(num=100)
# Custom 1270-1350 nm range with a 1 nm step
wl = wl_o(step=0.001, wl_min=1.27, wl_max=1.35)
```
Source code in src/sax/constants.py
def wl_o(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_O_MIN,
    wl_max: float = WL_O_MAX,
) -> FloatArray1D:
    """Generate wavelength array in the O-band (Original band: 1260-1360 nm).

    Args:
        step: Wavelength step size in μm. Used when num is None.
        num: Number of wavelength points. If provided, step is ignored.
        wl_min: Minimum wavelength in μm. Defaults to O-band minimum.
        wl_max: Maximum wavelength in μm. Defaults to O-band maximum.

    Returns:
        1D array of wavelengths in μm.

    Example:
        ```python
        # Generate O-band wavelengths with default step
        wl = wl_o()
        # Generate 100 equally spaced points in O-band
        wl = wl_o(num=100)
        # Custom range with specific step
        wl = wl_o(step=0.001, wl_min=1.27, wl_max=1.35)
        ```
    """
    return _wl(step=step, num=num, wl_min=wl_min, wl_max=wl_max)
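For the custom O-band call in the example above, the number of samples follows from the range and the step. Assuming both endpoints are sampled when the range is a whole number of steps (the `_wl` source is not shown here, so this is an assumption), the count works out as:

$$
N = \frac{\lambda_{\max} - \lambda_{\min}}{\Delta\lambda} + 1
  = \frac{1.35\ \mu\mathrm{m} - 1.27\ \mu\mathrm{m}}{0.001\ \mu\mathrm{m}} + 1
  = 80 + 1 = 81
$$

If `wl_max` is excluded, as with `numpy.arange`, the count is 80 instead.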

wl_s

wl_s(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_S_MIN,
    wl_max: float = WL_S_MAX,
) -> FloatArray1D

Generate a wavelength array, in μm, spanning the S-band (Short band: 1460-1530 nm).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step` | `float` | Wavelength step size in μm. Used when `num` is `None`. | `DEFAULT_WL_STEP` |
| `num` | `int \| None` | Number of wavelength points. If provided, `step` is ignored. | `None` |
| `wl_min` | `float` | Minimum wavelength in μm. Defaults to S-band minimum. | `WL_S_MIN` |
| `wl_max` | `float` | Maximum wavelength in μm. Defaults to S-band maximum. | `WL_S_MAX` |

Returns:

| Type | Description |
| --- | --- |
| `FloatArray1D` | 1D array of wavelengths in μm. |

Example
```python
from sax.constants import wl_s

# Generate S-band wavelengths with the default step
wl = wl_s()
# Generate 100 equally spaced points in the S-band
wl = wl_s(num=100)
```
Source code in src/sax/constants.py
def wl_s(
    *,
    step: float = DEFAULT_WL_STEP,
    num: int | None = None,
    wl_min: float = WL_S_MIN,
    wl_max: float = WL_S_MAX,
) -> FloatArray1D:
    """Generate wavelength array in the S-band (Short band: 1460-1530 nm).

    Args:
        step: Wavelength step size in μm. Used when num is None.
        num: Number of wavelength points. If provided, step is ignored.
        wl_min: Minimum wavelength in μm. Defaults to S-band minimum.
        wl_max: Maximum wavelength in μm. Defaults to S-band maximum.

    Returns:
        1D array of wavelengths in μm.

    Example:
        ```python
        # Generate S-band wavelengths with default step
        wl = wl_s()
        # Generate 100 equally spaced points in S-band
        wl = wl_s(num=100)
        ```
    """
    return _wl(step=step, num=num, wl_min=wl_min, wl_max=wl_max)
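The five band functions differ only in their default `wl_min` and `wl_max`. A small lookup table built from the band ranges stated in their docstrings can report which band a given wavelength falls in. `BANDS_UM` and `band_of` are hypothetical names, not SAX API, and the μm values are conversions of the nm ranges above rather than the actual `WL_*_MIN`/`WL_*_MAX` constants, which this excerpt does not show. Assigning a shared edge such as 1565 nm to the longer-wavelength band is a design choice in this sketch.

```python
from __future__ import annotations

# Band edges in μm, converted from the nm ranges in the docstrings above
# (not read from SAX's WL_*_MIN / WL_*_MAX constants)
BANDS_UM: dict[str, tuple[float, float]] = {
    "O": (1.260, 1.360),
    "E": (1.360, 1.460),
    "S": (1.460, 1.530),
    "C": (1.530, 1.565),
    "L": (1.565, 1.625),
}


def band_of(wl_um: float) -> str | None:
    """Return the band containing wl_um (in μm), or None outside 1.260-1.625 μm.

    Intervals are half-open [lo, hi), so a shared edge such as 1.565 μm belongs
    to the longer-wavelength band; the L-band upper edge is included explicitly.
    """
    for name, (lo, hi) in BANDS_UM.items():
        if lo <= wl_um < hi:
            return name
    if wl_um == BANDS_UM["L"][1]:
        return "L"
    return None
```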