models

SAX Models.

Modules:

- bends: SAX Bend Models.
- couplers: SAX Coupler Models.
- crossings: SAX Crossing Models.
- factories: SAX Default Models.
- mmis: SAX MMI Models.
- splitters: SAX Default Models.
- straight: SAX Default Models.

Functions:

- attenuator: Simple optical attenuator model.
- bend: Simple waveguide bend model.
- copier: Generate a power copying/amplification device model.
- coupler: Dispersive directional coupler model.
- coupler_ideal: Ideal 2x2 directional coupler model.
- crossing_ideal: Ideal waveguide crossing model.
- grating_coupler: Grating coupler model for fiber-chip coupling.
- mmi1x2: Realistic 1x2 MMI splitter model with dispersion and loss.
- mmi1x2_ideal: Ideal 1x2 multimode interference (MMI) splitter model.
- mmi2x2: Realistic 2x2 MMI coupler model with dispersion and asymmetry.
- mmi2x2_ideal: Ideal 2x2 multimode interference (MMI) coupler model.
- model_2port: Generate a general 2-port model factory with custom port names.
- model_3port: Generate a general 3-port model factory with custom port names.
- model_4port: Generate a general 4-port model factory with custom port names.
- passthru: Generate a multi-port pass-through device model.
- phase_shifter: Simple voltage-controlled phase shifter model.
- splitter_ideal: Ideal 1x2 power splitter model.
- unitary: Generate a unitary N×M optical device model.

attenuator

attenuator(*, loss: FloatArrayLike = 0.0) -> SDict

Simple optical attenuator model.

This function models a variable optical attenuator (VOA) that provides controllable attenuation without phase change. The device reduces optical power by a specified amount while maintaining the phase of the transmitted light.

The attenuator is reciprocal and introduces no loss beyond the specified value. Because that loss removes power, the S-matrix is not unitary for loss > 0; the attenuation is an intentional signal reduction rather than a parasitic effect.

Parameters:

- loss (FloatArrayLike, default 0.0): Attenuation in decibels (dB); can be wavelength-dependent for spectral filtering applications. Positive values represent attenuation:
  - loss = 0.0: No attenuation (unity transmission)
  - loss = 3.0: 3 dB attenuation (50% power transmission)
  - loss = 10.0: 10 dB attenuation (10% power transmission)

Returns:

- SDict: S-matrix dictionary containing the complex transmission coefficient.

Examples:

Fixed attenuator:

import jax.numpy as jnp
import sax

# 6 dB attenuator
s_matrix = sax.models.attenuator(loss=6.0)
transmission = abs(s_matrix[("in0", "out0")]) ** 2
print(f"Power transmission: {transmission:.3f}")  # Should be ~0.251

# Verify in dB
loss_dB = -10 * jnp.log10(transmission)
print(f"Loss: {loss_dB:.1f} dB")  # Should be 6.0 dB

Variable attenuator:

import numpy as np

losses = np.linspace(0, 20, 101)  # 0 to 20 dB
transmissions = []
for loss_val in losses:
    s = sax.models.attenuator(loss=loss_val)
    transmissions.append(abs(s[("in0", "out0")]) ** 2)

Wavelength-dependent attenuator (filter):

wavelengths = np.linspace(1.5, 1.6, 101)
# Gaussian spectral filter
center_wl = 1.55
bandwidth = 0.01
spectral_loss = 20 * np.exp(-(((wavelengths - center_wl) / bandwidth) ** 2))
s_matrices = sax.models.attenuator(loss=spectral_loss)
Note

This model represents an ideal attenuator with:
- No wavelength dependence (unless explicitly provided)
- No reflection at input or output
- No phase change
- Perfect reciprocity

Real attenuators may exhibit:
- Wavelength-dependent loss variation
- Small amounts of reflection
- Polarization-dependent loss (PDL)
- Temperature sensitivity

For more complex filtering behavior, consider combining multiple components or using specialized filter models.
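The conversion from `loss` in dB to the transmitted amplitude, and the fact that the S-matrix stops being unitary once `loss > 0`, can be checked with a short NumPy sketch. It reuses the expression `10 ** (-loss / 20)` from the source listing below; the helper name `attenuator_sketch` and its plain-dict return value are hypothetical and not part of SAX.

```python
import numpy as np


def attenuator_sketch(loss_db=0.0):
    """Hypothetical NumPy stand-in for the attenuator S-dict (not the SAX API)."""
    # dB values describe power, so the amplitude coefficient uses a factor of 20.
    t = np.asarray(10 ** (-np.asarray(loss_db, dtype=float) / 20), dtype=complex)
    # Reciprocal two-port: the same coefficient in both directions, no reflection.
    return {("in0", "out0"): t, ("out0", "in0"): t}


s = attenuator_sketch(6.0)
power = abs(s[("in0", "out0")]) ** 2
print(f"Power transmission: {power:.3f}")  # 0.251
print(f"Loss: {-10 * np.log10(power):.1f} dB")  # 6.0 dB

# Dense 2x2 S-matrix with port order (in0, out0).
S = np.array([[0, s[("in0", "out0")]], [s[("out0", "in0")], 0]])
# S^H S = |t|^2 * I, which differs from the identity whenever loss > 0.
print(np.allclose(S.conj().T @ S, np.eye(2)))  # False
```

Passing an array for `loss_db` gives one coefficient per entry, which is how a wavelength-dependent loss profile maps onto the same expression.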

Source code in src/sax/models/straight.py
@jax.jit
@validate_call
def attenuator(*, loss: sax.FloatArrayLike = 0.0) -> sax.SDict:
    """Simple optical attenuator model.

    This function models a variable optical attenuator (VOA) that provides
    controllable attenuation without phase change. The device reduces optical
    power by a specified amount while maintaining the phase of the transmitted
    light.

    The attenuator is reciprocal and introduces no loss beyond the specified
    value. Because that loss removes power, the S-matrix is not unitary for
    loss > 0; the attenuation is an intentional signal reduction rather than
    a parasitic effect.

    Args:
        loss: Attenuation in decibels (dB). Positive values represent attenuation:
            - loss = 0.0: No attenuation (unity transmission)
            - loss = 3.0: 3 dB attenuation (50% power transmission)
            - loss = 10.0: 10 dB attenuation (10% power transmission)
            Can be wavelength-dependent for spectral filtering applications.
            Defaults to 0.0 dB.

    Returns:
        S-matrix dictionary containing the complex transmission coefficient.

    Examples:
        Fixed attenuator:

        ```python
        import jax.numpy as jnp
        import sax

        # 6 dB attenuator
        s_matrix = sax.models.attenuator(loss=6.0)
        transmission = abs(s_matrix[("in0", "out0")]) ** 2
        print(f"Power transmission: {transmission:.3f}")  # Should be ~0.251

        # Verify in dB
        loss_dB = -10 * jnp.log10(transmission)
        print(f"Loss: {loss_dB:.1f} dB")  # Should be 6.0 dB
        ```

        Variable attenuator:

        ```python
        import numpy as np

        losses = np.linspace(0, 20, 101)  # 0 to 20 dB
        transmissions = []
        for loss_val in losses:
            s = sax.models.attenuator(loss=loss_val)
            transmissions.append(abs(s[("in0", "out0")]) ** 2)
        ```

        Wavelength-dependent attenuator (filter):

        ```python
        wavelengths = np.linspace(1.5, 1.6, 101)
        # Gaussian spectral filter
        center_wl = 1.55
        bandwidth = 0.01
        spectral_loss = 20 * np.exp(-(((wavelengths - center_wl) / bandwidth) ** 2))
        s_matrices = sax.models.attenuator(loss=spectral_loss)
        ```

    Note:
        This model represents an ideal attenuator with:
        - No wavelength dependence (unless explicitly provided)
        - No reflection at input or output
        - No phase change
        - Perfect reciprocity

        Real attenuators may exhibit:
        - Wavelength-dependent loss variation
        - Small amounts of reflection
        - Polarization-dependent loss (PDL)
        - Temperature sensitivity

        For more complex filtering behavior, consider combining multiple
        components or using specialized filter models.
    """
    transmission = jnp.asarray(10 ** (-loss / 20), dtype=complex)
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.out0): transmission,
        }
    )

bend

bend(
    wl: FloatArrayLike = 1.55,
    wl0: FloatArrayLike = 1.55,
    neff: FloatArrayLike = 2.34,
    ng: FloatArrayLike = 3.4,
    length: FloatArrayLike = 10.0,
    loss_dB_cm: FloatArrayLike = 0.1,
) -> SDict

Simple waveguide bend model.

This function models a simple bend in a waveguide by treating it as an equivalent straight waveguide with the same effective length and properties. The bend is approximated using the straight waveguide model, inheriting all its dispersion and loss characteristics.
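Since `bend` forwards all of its arguments to `straight` (see the source listing below), its response can be illustrated with a small NumPy model of a dispersive straight waveguide. This is a sketch under stated assumptions: `straight_sketch` is a hypothetical helper, and its linear expansion of neff is derived from the ng definition given in the parameter list rather than copied from SAX's `straight` source, which is not shown on this page.

```python
import numpy as np


def straight_sketch(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss_dB_cm=0.1):
    """Hypothetical first-order dispersive waveguide (not SAX's `straight`).

    Returns the complex in0 -> out0 transmission; all lengths in micrometers.
    """
    wl = np.asarray(wl, dtype=float)
    # Linear expansion of neff around wl0, using ng = neff - wl0 * dneff/dwl.
    neff_wl = neff - (wl - wl0) * (ng - neff) / wl0
    # Accumulated propagation phase.
    phase = 2 * np.pi * neff_wl * length / wl
    # dB/cm times length in cm (1 cm = 1e4 um), then power dB -> amplitude.
    amplitude = 10 ** (-loss_dB_cm * length / 1e4 / 20)
    return amplitude * np.exp(1j * phase)


wavelengths = np.linspace(1.5, 1.6, 101)
t = straight_sketch(wl=wavelengths, length=50.0, neff=2.35, ng=3.5)
power = np.abs(t) ** 2  # the same slightly-below-1 value at every wavelength

# Group index from the phase slope: ng = -(wl**2 / (2*pi*L)) * dphi/dwl.
wl0, L, h = 1.55, 1000.0, 1e-6
ratio = straight_sketch(wl=wl0 + h, length=L, loss_dB_cm=0.0) / straight_sketch(
    wl=wl0 - h, length=L, loss_dB_cm=0.0
)
ng_est = -(wl0**2) / (2 * np.pi * L) * np.angle(ratio) / (2 * h)
print(f"{ng_est:.4f}")  # 3.4000 (the default ng)
```

Recovering `ng` from the phase slope is a quick consistency check that a dispersion model actually matches the ng definition used for its parameters.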

Parameters:

- wl (FloatArrayLike, default 1.55): Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations.
- wl0 (FloatArrayLike, default 1.55): Reference wavelength in micrometers used for the dispersion calculation. This is typically the design wavelength where neff is specified.
- neff (FloatArrayLike, default 2.34): Effective refractive index at the reference wavelength. This is the effective index of the fundamental mode and determines the phase velocity. The default is typical for silicon.
- ng (FloatArrayLike, default 3.4): Group refractive index at the reference wavelength, used to calculate chromatic dispersion: ng = neff - lambda * d(neff)/d(lambda). Typically ng > neff for normal dispersion.
- length (FloatArrayLike, default 10.0): Physical length of the waveguide in micrometers. Determines both the total phase accumulation and the loss.
- loss_dB_cm (FloatArrayLike, default 0.1): Propagation loss in dB/cm, including material absorption, scattering losses, and other loss mechanisms. Set to 0.0 for lossless modeling.

Returns:

- SDict: The bend S-matrix.

Examples:

Basic bend simulation:

import numpy as np
import sax

# Single wavelength simulation
s_matrix = sax.models.bend(wl=1.55, length=20.0, loss_dB_cm=0.2)
print(f"Transmission: {s_matrix[('in0', 'out0')]}")

Multi-wavelength analysis:

wavelengths = np.linspace(1.5, 1.6, 101)
s_matrices = sax.models.bend(
    wl=wavelengths, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5
)
transmission = np.abs(s_matrices[("in0", "out0")]) ** 2
               o2/out0
               |
              /
             /
o1/in0 _____/
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wavelengths = np.linspace(1.5, 1.6, 101)
s = sax.models.bend(wl=wavelengths, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
transmission = np.abs(s[("o1", "o2")]) ** 2

plt.figure()
plt.plot(wavelengths, transmission)
plt.xlabel("Wavelength (μm)")
plt.ylabel("Transmission")
Note

This model treats the bend as an equivalent straight waveguide and does not account for:
- Mode coupling between bend eigenmodes
- Radiation losses due to bending
- Polarization effects in bends
- Non-linear dispersion effects

For more accurate bend modeling, consider using dedicated bend models that account for these physical effects.

Source code in src/sax/models/bends.py
@jax.jit
@validate_call
def bend(
    wl: sax.FloatArrayLike = 1.55,
    wl0: sax.FloatArrayLike = 1.55,
    neff: sax.FloatArrayLike = 2.34,
    ng: sax.FloatArrayLike = 3.4,
    length: sax.FloatArrayLike = 10.0,
    loss_dB_cm: sax.FloatArrayLike = 0.1,
) -> sax.SDict:
    """Simple waveguide bend model.

    This function models a simple bend in a waveguide by treating it as an
    equivalent straight waveguide with the same effective length and properties.
    The bend is approximated using the straight waveguide model, inheriting all
    its dispersion and loss characteristics.

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Reference wavelength in micrometers used for dispersion calculation.
            This is typically the design wavelength where neff is specified.
            Defaults to 1.55 μm.
        neff: Effective refractive index at the reference wavelength. This value
            represents the fundamental mode effective index and determines the
            phase velocity. Defaults to 2.34 (typical for silicon).
        ng: Group refractive index at the reference wavelength. Used to calculate
            chromatic dispersion: ng = neff - lambda * d(neff)/d(lambda).
            Typically ng > neff for normal dispersion. Defaults to 3.4.
        length: Physical length of the waveguide in micrometers. Determines both
            the total phase accumulation and loss. Defaults to 10.0 μm.
        loss_dB_cm: Propagation loss in dB/cm. Includes material absorption,
            scattering losses, and other loss mechanisms. Set to 0.0 for
            lossless modeling. Defaults to 0.1 dB/cm.

    Returns:
        The bend s-matrix

    Examples:
        Basic bend simulation:

    ```python
    import numpy as np
    import sax

    # Single wavelength simulation
    s_matrix = sax.models.bend(wl=1.55, length=20.0, loss_dB_cm=0.2)
    print(f"Transmission: {s_matrix[('in0', 'out0')]}")
    ```

    Multi-wavelength analysis:

    ```python
    wavelengths = np.linspace(1.5, 1.6, 101)
    s_matrices = sax.models.bend(
        wl=wavelengths, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5
    )
    transmission = np.abs(s_matrices[("in0", "out0")]) ** 2
    ```

    ```
                   o2/out0
                   |
                  /
                 /
    o1/in0 _____/
    ```

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wavelengths = np.linspace(1.5, 1.6, 101)
    s = sax.models.bend(wl=wavelengths, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
    transmission = np.abs(s[("o1", "o2")]) ** 2

    plt.figure()
    plt.plot(wavelengths, transmission)
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("Transmission")
    ```

    Note:
        This model treats the bend as an equivalent straight waveguide and does not
        account for:
        - Mode coupling between bend eigenmodes
        - Radiation losses due to bending
        - Polarization effects in bends
        - Non-linear dispersion effects

        For more accurate bend modeling, consider using dedicated bend models that
        account for these physical effects.
    """
    return straight(
        wl=wl,
        wl0=wl0,
        neff=neff,
        ng=ng,
        length=length,
        loss_dB_cm=loss_dB_cm,
    )

copier

copier(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> SCooModel

Generate a power copying/amplification device model.

This function creates a model for optical devices that can copy or amplify input signals to multiple outputs. Unlike unitary devices, copiers can provide gain and may not conserve power, making them useful for modeling amplifiers, laser arrays, or ideal copying devices.

The copier model creates controlled coupling between inputs and outputs without the strict unitarity constraints of passive devices, allowing for flexible gain and splitting characteristics.

Parameters:

- num_inputs (int, required): Number of input ports for the device.
- num_outputs (int, required): Number of output ports for the device.
- reciprocal (bool, default True): If True, the device is reciprocal: forward and reverse transmissions are equal.
- diagonal (bool, default False): If True, creates diagonal coupling (each input couples to only one output); this mode is intended for num_inputs == num_outputs. If False, creates full coupling between all input-output pairs.

Returns:

- SCooModel: A copier model.

Examples:

Create a 1×4 optical amplifier/splitter:

import sax

amp_splitter = sax.models.copier(1, 4, reciprocal=False)
Si, Sj, Sx, port_map = amp_splitter(wl=1.55)
# Single input amplified and split to 4 outputs

Create a 2×2 bidirectional amplifier:

bidir_amp = sax.models.copier(2, 2, reciprocal=True, diagonal=True)
Si, Sj, Sx, port_map = bidir_amp(wl=1.55)
# Each input amplified to corresponding output

Create a broadcast network:

broadcaster = sax.models.copier(3, 6, reciprocal=False)
Si, Sj, Sx, port_map = broadcaster(wl=1.55)
# Each input broadcast to all outputs with gain

Multi-wavelength gain device:

import numpy as np

wavelengths = np.linspace(1.5, 1.6, 101)
gain_device = sax.models.copier(1, 2, diagonal=False)
Si, Sj, Sx, port_map = gain_device(wl=wavelengths)
# Wavelength-independent copying/amplification
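Each call above returns the S-matrix in sparse COO form. A hedged NumPy sketch for scattering such a tuple into a dense array, assuming (as in the source listing below) that `Si`/`Sj` are row/column indices, the last axis of `Sx` runs over the nonzero entries, and `port_map` has one index per port. The hand-built arrays mirror what `copier(1, 2, reciprocal=False)` would produce at three wavelengths; the port names are illustrative, and `scoo_to_dense` is a hypothetical helper, not a SAX function.

```python
import numpy as np


def scoo_to_dense(Si, Sj, Sx, port_map):
    """Hypothetical helper: scatter COO entries into a dense (..., N, N) array."""
    n = len(port_map)
    dense = np.zeros((*Sx.shape[:-1], n, n), dtype=Sx.dtype)
    # Leading axes of Sx (e.g. wavelength) are kept; the last axis is scattered.
    dense[..., Si, Sj] = Sx
    return dense


# Hand-built data mirroring copier(1, 2, reciprocal=False) at 3 wavelengths.
Si = np.array([0, 0])
Sj = np.array([1, 2])
Sx = np.ones((3, 2))
port_map = {"in0": 0, "out0": 1, "out1": 2}  # illustrative port names

S = scoo_to_dense(Si, Sj, Sx, port_map)
# Sum of |S|^2 over the in0 row is 2.0, i.e. more than 1 (gain).
print(np.sum(np.abs(S[0, port_map["in0"], :]) ** 2))
```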
Note

The copier model differs from unitary devices in several key ways:

  1. Power Conservation: May not conserve power (can provide gain)
  2. Unitarity: S-matrix may not be unitary (|S|² can exceed 1)
  3. Physical Realizability: Requires active elements or non-reciprocal materials

Applications include:
- Optical amplifier modeling (SOAs, EDFAs, Raman)
- Laser array modeling
- Signal distribution networks
- Quantum information processing (ideal copying devices)
- Gain medium characterization

The coupling matrix construction follows similar principles to the unitary model but without SVD normalization, allowing for arbitrary transmission coefficients.
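The difference from the unitary construction can be seen on the input-to-output block alone. In this NumPy sketch, `svd_normalize` encodes an assumption about what "SVD normalization" means (replacing every singular value by 1, i.e. taking the polar factor); it is not copied from SAX's `unitary` source, which is not shown on this page.

```python
import numpy as np


def full_coupling(num_inputs, num_outputs):
    """All-ones input-to-output block, as the copier builds with diagonal=False."""
    return np.ones((num_inputs, num_outputs))


def svd_normalize(C):
    """Assumed meaning of 'SVD normalization': set every singular value to 1."""
    U, _, Vh = np.linalg.svd(C, full_matrices=False)
    return U @ Vh


C = full_coupling(1, 4)
copier_power = np.sum(np.abs(C) ** 2)  # 4.0: each output receives a full copy
unitary_power = np.sum(np.abs(svd_normalize(C)) ** 2)  # ~1.0: power is split
print(copier_power, unitary_power)
```

Without the normalization step, every output of the 1×4 copier carries unit amplitude, so the total output power is four times the input.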

For realistic amplifier modeling, consider adding:
- Wavelength-dependent gain profiles
- Noise figures and ASE generation
- Saturation effects
- Polarization dependencies

Note: True quantum copying of unknown states is forbidden by the no-cloning theorem, but this model can represent classical optical copying or quantum copying of known states.

Source code in src/sax/models/factories.py
@validate_call
def copier(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> sax.SCooModel:
    """Generate a power copying/amplification device model.

    This function creates a model for optical devices that can copy or amplify
    input signals to multiple outputs. Unlike unitary devices, copiers can
    provide gain and may not conserve power, making them useful for modeling
    amplifiers, laser arrays, or ideal copying devices.

    The copier model creates controlled coupling between inputs and outputs
    without the strict unitarity constraints of passive devices, allowing
    for flexible gain and splitting characteristics.

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        reciprocal: If True, the device exhibits reciprocal behavior where
            forward and reverse transmissions are equal. Defaults to True.
        diagonal: If True, creates diagonal coupling (each input couples to
            only one output). If False, creates full coupling between all
            input-output pairs. Defaults to False.

    Returns:
        A copier model

    Examples:
        Create a 1×4 optical amplifier/splitter:

        ```python
        import sax

        amp_splitter = sax.models.copier(1, 4, reciprocal=False)
        Si, Sj, Sx, port_map = amp_splitter(wl=1.55)
        # Single input amplified and split to 4 outputs
        ```

        Create a 2×2 bidirectional amplifier:

        ```python
        bidir_amp = sax.models.copier(2, 2, reciprocal=True, diagonal=True)
        Si, Sj, Sx, port_map = bidir_amp(wl=1.55)
        # Each input amplified to corresponding output
        ```

        Create a broadcast network:

        ```python
        broadcaster = sax.models.copier(3, 6, reciprocal=False)
        Si, Sj, Sx, port_map = broadcaster(wl=1.55)
        # Each input broadcast to all outputs with gain
        ```

        Multi-wavelength gain device:

        ```python
        import numpy as np

        wavelengths = np.linspace(1.5, 1.6, 101)
        gain_device = sax.models.copier(1, 2, diagonal=False)
        Si, Sj, Sx, port_map = gain_device(wl=wavelengths)
        # Wavelength-independent copying/amplification
        ```

    Note:
        The copier model differs from unitary devices in several key ways:

        1. Power Conservation: May not conserve power (can provide gain)
        2. Unitarity: S-matrix may not be unitary (|S|² can exceed 1)
        3. Physical Realizability: Requires active elements or non-reciprocal materials

        Applications include:
        - Optical amplifier modeling (SOAs, EDFAs, Raman)
        - Laser array modeling
        - Signal distribution networks
        - Quantum information processing (ideal copying devices)
        - Gain medium characterization

        The coupling matrix construction follows similar principles to the
        unitary model but without SVD normalization, allowing for arbitrary
        transmission coefficients.

        For realistic amplifier modeling, consider adding:
        - Wavelength-dependent gain profiles
        - Noise figures and ASE generation
        - Saturation effects
        - Polarization dependencies

        Note: True quantum copying of unknown states is forbidden by the
        no-cloning theorem, but this model can represent classical optical
        copying or quantum copying of known states.
    """
    # let's create the squared S-matrix:
    S = jnp.zeros((num_inputs + num_outputs, num_inputs + num_outputs), dtype=float)

    if not diagonal:
        S = S.at[:num_inputs, num_inputs:].set(1)
    else:
        r = jnp.arange(
            num_inputs,
            dtype=int,
        )  # == range(num_outputs); diagonal coupling assumes num_inputs == num_outputs!
        S = S.at[r, num_inputs + r].set(1)

    if reciprocal:
        if not diagonal:
            S = S.at[num_inputs:, :num_inputs].set(1)
        else:
            # reciprocal only works if num_inputs == num_outputs!
            r = jnp.arange(num_inputs, dtype=int)  # == range(num_outputs)
            S = S.at[num_inputs + r, r].set(1)

    # let's convert it in SCOO format:
    Si, Sj = jnp.where(jnp.sqrt(sax.EPS) < S)
    Sx = S[Si, Sj]

    # the last missing piece is a port map:
    p = sax.PortNamer(num_inputs, num_outputs)
    pm = {
        **{p[i]: i for i in range(num_inputs)},
        **{p[i + num_inputs]: i + num_inputs for i in range(num_outputs)},
    }

    @validate_call
    def func(wl: sax.FloatArrayLike = 1.5) -> sax.SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    func.__name__ = f"copier_{num_inputs}_{num_outputs}"
    func.__qualname__ = f"copier_{num_inputs}_{num_outputs}"
    return jax.jit(func)

coupler

coupler(
    *,
    wl: FloatArrayLike = 1.55,
    wl0: FloatArrayLike = 1.55,
    length: FloatArrayLike = 0.0,
    coupling0: FloatArrayLike = 0.2,
    dk1: FloatArrayLike = 1.2435,
    dk2: FloatArrayLike = 5.3022,
    dn: FloatArrayLike = 0.02,
    dn1: FloatArrayLike = 0.1169,
    dn2: FloatArrayLike = 0.4821,
) -> SDict

Dispersive directional coupler model.

This function models a realistic directional coupler with wavelength-dependent coupling and chromatic dispersion effects. The model includes both the coupling region dispersion and bend-induced coupling contributions.

The model is based on coupled-mode theory and includes:
- Wavelength-dependent coupling coefficient
- Effective index difference between even/odd supermodes
- Both first and second-order dispersion terms
- Bend region coupling contributions

Equations adapted from photontorch: https://github.com/flaport/photontorch/blob/master/photontorch/components/directionalcouplers.py

The total coupling is kappa = coupling0 + coupling, where coupling0 comes from the bend regions and coupling from the straight interaction section of the given length.

Parameters:

- wl (FloatArrayLike, default 1.55): Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations.
- wl0 (FloatArrayLike, default 1.55): Reference wavelength in micrometers for the dispersion expansion. Typically the center design wavelength.
- length (FloatArrayLike, default 0.0): Coupling length in micrometers, i.e. the length over which the two waveguides are in close proximity.
- coupling0 (FloatArrayLike, default 0.2): Base coupling at the reference wavelength from the bend regions or other coupling mechanisms, obtained from FDTD simulations or measurements. It enters the model as a coupling angle in radians (the cross-coupled power is sin² of the total angle), not as a power fraction.
- dk1 (FloatArrayLike, default 1.2435): First-order derivative of the coupling with respect to wavelength (∂κ/∂λ). Units: μm⁻¹.
- dk2 (FloatArrayLike, default 5.3022): Second-order derivative of the coupling with respect to wavelength (∂²κ/∂λ²). Units: μm⁻².
- dn (FloatArrayLike, default 0.02): Effective index difference between the even and odd supermodes at the reference wavelength. Determines the beat length.
- dn1 (FloatArrayLike, default 0.1169): First-order derivative of the effective index difference with respect to wavelength (∂Δn/∂λ). Units: μm⁻¹.
- dn2 (FloatArrayLike, default 0.4821): Second-order derivative of the effective index difference with respect to wavelength (∂²Δn/∂λ²). Units: μm⁻².

Returns:

- SDict: The coupler S-matrix.

Examples:

Basic dispersive coupler:

import sax
import numpy as np

s_matrix = sax.models.coupler(
    wl=1.55,
    length=10.0,  # 10 μm coupling length
    coupling0=0.1,
    dn=0.015
)
bar_transmission = abs(s_matrix[('in0', 'out0')])**2
cross_transmission = abs(s_matrix[('in0', 'out1')])**2

Wavelength sweep analysis:

wavelengths = np.linspace(1.5, 1.6, 101)
s_matrices = sax.models.coupler(
    wl=wavelengths,
    length=20.0,
    coupling0=0.2,
    dn=0.02,
    dn1=0.1  # Include dispersion
)
bar_power = np.abs(s_matrices[('in0', 'out0')])**2
cross_power = np.abs(s_matrices[('in0', 'out1')])**2

Design for specific coupling:

# Design for 3 dB coupling at 1.55 μm
# Cross power is sin²(coupling0 + π·dn·length/wl); with coupling0 = 0,
# a 3 dB split needs π·dn·length/wl = π/4, i.e. length = wl / (4·dn)
s_matrix = sax.models.coupler(
    wl=1.55,
    length=1.55 / (4 * 0.02),  # = 19.375 μm
    coupling0=0.0,
    dn=0.02
)
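The length for a given split follows from the transmission expressions in the source listing below, where the bar and cross amplitudes are cos and -sin of the angle kappa0 + kappa1 * length. A NumPy sketch of that calculation; `coupling_angle` is a hypothetical helper that re-implements those lines for illustration, not a SAX function.

```python
import numpy as np


def coupling_angle(wl, length, wl0=1.55, coupling0=0.2, dk1=1.2435, dk2=5.3022,
                   dn=0.02, dn1=0.1169, dn2=0.4821):
    """Hypothetical helper mirroring kappa0 + kappa1 * length in the coupler source."""
    dwl = wl - wl0
    dn_wl = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    kappa0 = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2
    kappa1 = np.pi * dn_wl / wl
    return kappa0 + kappa1 * length


# Bar power is cos^2(angle) and cross power is sin^2(angle), so a 3 dB split
# needs angle = pi/4. At wl = wl0 all dispersion terms vanish, leaving
# coupling0 + pi * dn * length / wl0 = pi/4.
wl0, dn, coupling0 = 1.55, 0.02, 0.0
length_3db = (np.pi / 4 - coupling0) / (np.pi * dn / wl0)
print(f"{length_3db:.3f} um")  # 19.375 um

angle = coupling_angle(wl0, length_3db, coupling0=coupling0, dn=dn)
print(f"bar={np.cos(angle) ** 2:.3f}, cross={np.sin(angle) ** 2:.3f}")  # bar=0.500, cross=0.500
```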
    in1/o2 -----                      ----- out1/o3
                \ ◀-----length-----▶ /
                --------------------
    coupling0/2      coupling      coupling0/2
                --------------------
                /                    \
    in0/o1 -----                      ----- out0/o4
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wavelengths = np.linspace(1.5, 1.6, 101)
s = sax.models.coupler(
    wl=wavelengths,
    length=15.7,
    coupling0=0.0,
    dn=0.02,
)
bar_power = np.abs(s[("o1", "o4")]) ** 2
cross_power = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wavelengths, bar_power, label="Bar")
plt.plot(wavelengths, cross_power, label="Cross")
plt.xlabel("Wavelength (μm)")
plt.ylabel("Power")
plt.legend()
Note

The coupling angle follows the formula: κ_total = κ₀ + κ₁·L

Where:
- κ₀ = coupling0 + dk1·(λ-λ₀) + 0.5·dk2·(λ-λ₀)²
- κ₁ = π·Δn(λ)/λ, with Δn(λ) = dn + dn1·(λ-λ₀) + 0.5·dn2·(λ-λ₀)²
- L is the coupling length, so κ₁·L is the distributed coupling over the interaction length

The bar and cross transmissions are cos(κ_total) and -j·sin(κ_total), respectively.

This model assumes:
- Weak coupling regime (small coupling per unit length)
- A low-order (up to quadratic) dispersion expansion, valid for small wavelength deviations
- Symmetric coupler geometry
- No higher-order modes
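The complete wavelength-dependent S-matrix can be reproduced in a few lines of NumPy, which makes its lossless (unitary) and reciprocal structure easy to check. This is a hypothetical re-implementation of the coupler source listing below, returning dense 4×4 arrays with an assumed port order (in0, in1, out0, out1) instead of an S-dict; `coupler_smatrix` is not a SAX function.

```python
import numpy as np


def coupler_smatrix(wl, length=0.0, wl0=1.55, coupling0=0.2, dk1=1.2435,
                    dk2=5.3022, dn=0.02, dn1=0.1169, dn2=0.4821):
    """Hypothetical dense (..., 4, 4) coupler matrix; port order in0, in1, out0, out1."""
    wl = np.asarray(wl, dtype=float)
    dwl = wl - wl0
    dn_wl = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    theta = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2 + np.pi * dn_wl / wl * length
    tau, kappa = np.cos(theta), -np.sin(theta)
    # Transmission block T[i, j] = S(in_i, out_j).
    T = np.stack([np.stack([tau, 1j * kappa], -1), np.stack([1j * kappa, tau], -1)], -2)
    S = np.zeros((*wl.shape, 4, 4), dtype=complex)
    S[..., :2, 2:] = T
    S[..., 2:, :2] = np.swapaxes(T, -1, -2)  # reciprocal (out -> in) entries
    return S


wls = np.linspace(1.5, 1.6, 101)
S = coupler_smatrix(wls, length=19.375, coupling0=0.0)
cross_power = np.abs(S[:, 0, 3]) ** 2  # in0 -> out1
# Lossless and reciprocal at every wavelength.
print(np.allclose(np.conj(np.swapaxes(S, -1, -2)) @ S, np.eye(4)))  # True
print(np.allclose(S, np.swapaxes(S, -1, -2)))  # True
```

Because the bar and cross amplitudes are the cosine and sine of a single real angle, unitarity holds for any parameter values; dispersion only moves the split point across wavelength.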

Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def coupler(
    *,
    wl: sax.FloatArrayLike = 1.55,
    wl0: sax.FloatArrayLike = 1.55,
    length: sax.FloatArrayLike = 0.0,
    coupling0: sax.FloatArrayLike = 0.2,
    dk1: sax.FloatArrayLike = 1.2435,
    dk2: sax.FloatArrayLike = 5.3022,
    dn: sax.FloatArrayLike = 0.02,
    dn1: sax.FloatArrayLike = 0.1169,
    dn2: sax.FloatArrayLike = 0.4821,
) -> sax.SDict:
    r"""Dispersive directional coupler model.

    This function models a realistic directional coupler with wavelength-dependent
    coupling and chromatic dispersion effects. The model includes both the
    coupling region dispersion and bend-induced coupling contributions.

    The model is based on coupled-mode theory and includes:
    - Wavelength-dependent coupling coefficient
    - Effective index difference between even/odd supermodes
    - Both first and second-order dispersion terms
    - Bend region coupling contributions

    equations adapted from photontorch.
    https://github.com/flaport/photontorch/blob/master/photontorch/components/directionalcouplers.py

    kappa = coupling0 + coupling

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Reference wavelength in micrometers for dispersion expansion.
            Typically the center design wavelength. Defaults to 1.55 μm.
        length: Coupling length in micrometers. This is the length over which
            the two waveguides are in close proximity. Defaults to 0.0 μm.
        coupling0: Base coupling coefficient at the reference wavelength from
            bend regions or other coupling mechanisms. Obtained from FDTD
            simulations or measurements. Defaults to 0.2.
        dk1: First-order derivative of coupling coefficient with respect to
            wavelength (∂κ/∂λ). Units: μm⁻¹. Defaults to 1.2435.
        dk2: Second-order derivative of coupling coefficient with respect to
            wavelength (∂²κ/∂λ²). Units: μm⁻². Defaults to 5.3022.
        dn: Effective index difference between even and odd supermodes at
            the reference wavelength. Determines beating length. Defaults to 0.02.
        dn1: First-order derivative of effective index difference with respect
            to wavelength (∂Δn/∂λ). Units: μm⁻¹. Defaults to 0.1169.
        dn2: Second-order derivative of effective index difference with respect
            to wavelength (∂²Δn/∂λ²). Units: μm⁻². Defaults to 0.4821.

    Returns:
        The coupler s-matrix

    Examples:
        Basic dispersive coupler:

        ```python
        import sax
        import numpy as np

        s_matrix = sax.models.coupler(
            wl=1.55,
            length=10.0,  # 10 μm coupling length
            coupling0=0.1,
            dn=0.015
        )
        bar_transmission = abs(s_matrix[('in0', 'out0')])**2
        cross_transmission = abs(s_matrix[('in0', 'out1')])**2
        ```

        Wavelength sweep analysis:

        ```python
        wavelengths = np.linspace(1.5, 1.6, 101)
        s_matrices = sax.models.coupler(
            wl=wavelengths,
            length=20.0,
            coupling0=0.2,
            dn=0.02,
            dn1=0.1  # Include dispersion
        )
        bar_power = np.abs(s_matrices[('in0', 'out0')])**2
        cross_power = np.abs(s_matrices[('in0', 'out1')])**2
        ```

        Design for specific coupling:

        ```python
        # Design for 3 dB coupling at 1.55 μm
        # Cross power is sin²(coupling0 + π·dn·length/wl); with coupling0 = 0,
        # a 3 dB split needs π·dn·length/wl = π/4, i.e. length = wl / (4·dn)
        s_matrix = sax.models.coupler(
            wl=1.55,
            length=1.55 / (4 * 0.02),  # = 19.375 μm
            coupling0=0.0,
            dn=0.02
        )
        ```

    ```
        in1/o2 -----                      ----- out1/o3
                    \ ◀-----length-----▶ /
                    --------------------
        coupling0/2      coupling      coupling0/2
                    --------------------
                    /                    \
        in0/o1 -----                      ----- out0/o4

    ```

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wavelengths = np.linspace(1.5, 1.6, 101)
    s = sax.models.coupler(
        wl=wavelengths,
        length=15.7,
        coupling0=0.0,
        dn=0.02,
    )
    bar_power = np.abs(s[("o1", "o4")]) ** 2
    cross_power = np.abs(s[("o1", "o3")]) ** 2
    plt.figure()
    plt.plot(wavelengths, bar_power, label="Bar")
    plt.plot(wavelengths, cross_power, label="Cross")
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("Power")
    plt.legend()
    ```

    Note:
        The coupling angle follows the formula:
        κ_total = κ₀ + κ₁*L

        Where:
        - κ₀ = coupling0 + dk1*(λ-λ₀) + 0.5*dk2*(λ-λ₀)²
        - κ₁ = π*Δn(λ)/λ, with Δn(λ) = dn + dn1*(λ-λ₀) + 0.5*dn2*(λ-λ₀)²
        - L is the coupling length, so κ₁*L is the distributed coupling over
          the interaction length

        The bar and cross transmissions are cos(κ_total) and -j*sin(κ_total).

        This model assumes:
        - Weak coupling regime (small coupling per unit length)
        - Low-order (up to quadratic) dispersion expansion for small
          wavelength deviations
        - Symmetric coupler geometry
        - No higher-order modes
    """
    dwl = wl - wl0
    dn = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    kappa0 = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2
    kappa1 = jnp.pi * dn / wl

    tau = jnp.cos(kappa0 + kappa1 * length)
    kappa = -jnp.sin(kappa0 + kappa1 * length)
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): 1j * kappa,
            (p.in1, p.out0): 1j * kappa,
            (p.in1, p.out1): tau,
        }
    )

coupler_ideal

coupler_ideal(*, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 2x2 directional coupler model.

This function models an ideal 2x2 directional coupler with perfect coupling efficiency and no loss. The coupler divides input power between two output ports based on the coupling coefficient, with a 90-degree phase shift between cross-coupled terms to maintain unitarity.

The coupler has four ports arranged as:

- Input ports: in0, in1 (left side)
- Output ports: out0, out1 (right side)

Power coupling behavior:

- Fraction (1-coupling) transmits straight through (bar state)
- Fraction coupling is cross-coupled between arms
- Total power is conserved across all ports

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| coupling | FloatArrayLike | Power coupling coefficient between 0 and 1. Determines the fraction of power that couples between the two waveguide arms. coupling=0.0: no coupling (straight transmission); coupling=0.5: 3 dB coupler (equal bar/cross transmission); coupling=1.0: complete cross-coupling. Defaults to 0.5. | 0.5 |

Returns:

| Type | Description |
|------|-------------|
| SDict | The coupler S-matrix. |

Examples:

3dB coupler (equal splitting):

import sax

s_matrix = sax.models.coupler_ideal(coupling=0.5)
# Bar transmission (straight through)
bar_power = abs(s_matrix[("in0", "out0")]) ** 2
# Cross transmission (between arms)
cross_power = abs(s_matrix[("in0", "out1")]) ** 2
print(f"Bar power: {bar_power:.3f}")  # Should be 0.5
print(f"Cross power: {cross_power:.3f}")  # Should be 0.5

Asymmetric coupler:

s_matrix = sax.models.coupler_ideal(coupling=0.1)  # 10% coupling
bar_power = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.9
cross_power = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.1

Phase relationship verification:

import jax.numpy as jnp

s_matrix = sax.models.coupler_ideal(coupling=0.5)
bar_phase = jnp.angle(s_matrix[("in0", "out0")])
cross_phase = jnp.angle(s_matrix[("in0", "out1")])
phase_diff = cross_phase - bar_phase
print(f"Phase difference: {phase_diff:.3f} rad")  # Should be π/2
Note

This is an idealized lossless model that assumes:

- Perfect power conservation (unitary S-matrix)
- No reflection at any port
- Wavelength-independent behavior
- Symmetric coupling between both directions

The 90-degree phase shift in cross-coupled terms (1j factor) is required for maintaining S-matrix unitarity and represents the physical phase relationship in evanescent coupling.
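The claim that the 1j factor is needed for unitarity can be checked directly. Since the ideal coupler has no reflections, the full reciprocal S-matrix is unitary exactly when its 2x2 in-to-out block is. The sketch below builds that block with NumPy (coupler_matrix and is_unitary are hypothetical helpers, not SAX functions) and compares the 1j cross term against a purely real one.

```python
import numpy as np


def coupler_matrix(coupling, cross_phase=1j):
    """Hypothetical 2x2 block mapping (in0, in1) to (out0, out1)."""
    tau = np.sqrt(1 - coupling)   # bar amplitude
    kappa = np.sqrt(coupling)     # cross amplitude magnitude
    return np.array([[tau, cross_phase * kappa],
                     [cross_phase * kappa, tau]])


def is_unitary(m):
    # A lossless matrix satisfies M^H M = I
    return np.allclose(m.conj().T @ m, np.eye(m.shape[0]))


for c in (0.1, 0.5, 0.9):
    # 1j cross term keeps the matrix unitary; a real cross term does not
    print(c, is_unitary(coupler_matrix(c)), is_unitary(coupler_matrix(c, cross_phase=1.0)))
```

With a real cross term the off-diagonal entry of M^H M becomes 2·tau·kappa, which is nonzero for any 0 < coupling < 1; the 1j factor makes the two contributions cancel.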

Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def coupler_ideal(*, coupling: sax.FloatArrayLike = 0.5) -> sax.SDict:
    r"""Ideal 2x2 directional coupler model.

    This function models an ideal 2x2 directional coupler with perfect coupling
    efficiency and no loss. The coupler divides input power between two output
    ports based on the coupling coefficient, with a 90-degree phase shift
    between cross-coupled terms to maintain unitarity.

    The coupler has four ports arranged as:
    - Input ports: in0, in1 (left side)
    - Output ports: out0, out1 (right side)

    Power coupling behavior:
    - Fraction (1-coupling) transmits straight through (bar state)
    - Fraction coupling is cross-coupled between arms
    - Total power is conserved across all ports

    Args:
        coupling: Power coupling coefficient between 0 and 1. Determines the
            fraction of power that couples between the two waveguide arms.
            - coupling=0.0: No coupling (straight transmission)
            - coupling=0.5: 3dB coupler (equal bar/cross transmission)
            - coupling=1.0: Complete cross-coupling
            Defaults to 0.5.

    Returns:
        The coupler s-matrix

    Examples:
        3dB coupler (equal splitting):

        ```python
        import sax

        s_matrix = sax.models.coupler_ideal(coupling=0.5)
        # Bar transmission (straight through)
        bar_power = abs(s_matrix[("in0", "out0")]) ** 2
        # Cross transmission (between arms)
        cross_power = abs(s_matrix[("in0", "out1")]) ** 2
        print(f"Bar power: {bar_power:.3f}")  # Should be 0.5
        print(f"Cross power: {cross_power:.3f}")  # Should be 0.5
        ```

        Asymmetric coupler:

        ```python
        s_matrix = sax.models.coupler_ideal(coupling=0.1)  # 10% coupling
        bar_power = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.9
        cross_power = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.1
        ```

        Phase relationship verification:

        ```python
        import jax.numpy as jnp

        s_matrix = sax.models.coupler_ideal(coupling=0.5)
        bar_phase = jnp.angle(s_matrix[("in0", "out0")])
        cross_phase = jnp.angle(s_matrix[("in0", "out1")])
        phase_diff = cross_phase - bar_phase
        print(f"Phase difference: {phase_diff:.3f} rad")  # Should be π/2
        ```

    Note:
        This is an idealized lossless model that assumes:
        - Perfect power conservation (unitary S-matrix)
        - No reflection at any port
        - Wavelength-independent behavior
        - Symmetric coupling between both directions

        The 90-degree phase shift in cross-coupled terms (1j factor) is required
        for maintaining S-matrix unitarity and represents the physical phase
        relationship in evanescent coupling.

    """
    kappa = jnp.asarray(coupling**0.5)
    tau = jnp.asarray((1 - coupling) ** 0.5)
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): 1j * kappa,
            (p.in1, p.out0): 1j * kappa,
            (p.in1, p.out1): tau,
        },
    )

crossing_ideal

crossing_ideal(wl: FloatArrayLike = 1.5) -> SDict

Ideal waveguide crossing model.

This function models an ideal 4-port waveguide crossing where light from each input port is transmitted to the corresponding output port on the opposite side with no loss, reflection, or crosstalk. This represents the theoretical limit of a perfect crossing design.

The crossing has four ports arranged in a cross pattern:

- Port o1 (left input) connects to Port o3 (right output)
- Port o2 (bottom input) connects to Port o4 (top output)
- No coupling between orthogonal waveguides (no crosstalk)
- Perfect transmission with unit amplitude

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. The ideal model is wavelength-independent: wl only sets the shape of the returned arrays. Defaults to 1.5 μm. | 1.5 |

Returns:

| Type | Description |
|------|-------------|
| SDict | The crossing S-matrix. |

Examples:

Basic crossing simulation:

import sax

# The examples use optical port names (o1-o4)
sax.set_port_naming_strategy("optical")

# Single wavelength
s_matrix = sax.models.crossing_ideal(wl=1.55)
print(f"Horizontal transmission: {s_matrix[('o1', 'o3')]}")
print(f"Vertical transmission: {s_matrix[('o2', 'o4')]}")

Multi-wavelength analysis:

import numpy as np

wavelengths = np.linspace(1.5, 1.6, 101)
s_matrices = sax.models.crossing_ideal(wl=wavelengths)
# All transmissions should be unity
transmission = np.abs(s_matrices[("o1", "o3")]) ** 2
Note

This is an idealized model that assumes:

- Perfect phase matching between intersecting waveguides
- No scattering losses at the intersection
- No reflection at the crossing interfaces
- Complete isolation between orthogonal paths

Real crossings typically have finite insertion loss (0.1-1 dB), some crosstalk (-20 to -40 dB), and wavelength-dependent performance. For more realistic modeling, consider using fabrication-specific crossing models with measured parameters.
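The insertion-loss and crosstalk figures quoted above are power ratios in dB, while S-parameters are field amplitudes. The sketch below shows the conversion |S| = 10^(dB/20) for a non-ideal crossing; lossy_crossing_sketch is a hypothetical helper returning a plain dict, not a SAX model, and only a few representative port pairs are filled in.

```python
import numpy as np


def lossy_crossing_sketch(insertion_loss_db=0.2, crosstalk_db=-30.0):
    """Hypothetical non-ideal crossing (not a SAX model): S-parameter amplitudes."""
    # Convert power figures in dB to field amplitudes: |S| = 10**(dB / 20)
    thru = 10 ** (-insertion_loss_db / 20)
    xtalk = 10 ** (crosstalk_db / 20)
    return {
        ("o1", "o3"): thru,   # straight-through paths
        ("o2", "o4"): thru,
        ("o1", "o2"): xtalk,  # leakage from o1 into the orthogonal arm
        ("o1", "o4"): xtalk,
    }


s = lossy_crossing_sketch()
# Power seen at each listed port pair
print({k: round(float(np.abs(v)) ** 2, 5) for k, v in s.items()})
```

A passive crossing must satisfy |S(o1,o3)|² + |S(o1,o2)|² + |S(o1,o4)|² ≤ 1; the remainder is scattering loss.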

Source code in src/sax/models/crossings.py
@jax.jit
@validate_call
def crossing_ideal(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
    """Ideal waveguide crossing model.

    This function models an ideal 4-port waveguide crossing where light from each
    input port is transmitted to the corresponding output port on the opposite side
    with no loss, reflection, or crosstalk. This represents the theoretical limit
    of a perfect crossing design.

    The crossing has four ports arranged in a cross pattern:
    - Port o1 (left input) connects to Port o3 (right output)
    - Port o2 (bottom input) connects to Port o4 (top output)
    - No coupling between orthogonal waveguides (no crosstalk)
    - Perfect transmission with unit amplitude

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. The ideal model is wavelength-independent:
            wl only sets the shape of the returned arrays. Defaults to 1.5 μm.

    Returns:
        The crossing s-matrix

    Examples:
        Basic crossing simulation:

        ```python
        import sax

        # The examples use optical port names (o1-o4)
        sax.set_port_naming_strategy("optical")

        # Single wavelength
        s_matrix = sax.models.crossing_ideal(wl=1.55)
        print(f"Horizontal transmission: {s_matrix[('o1', 'o3')]}")
        print(f"Vertical transmission: {s_matrix[('o2', 'o4')]}")
        ```

        Multi-wavelength analysis:

        ```python
        import numpy as np

        wavelengths = np.linspace(1.5, 1.6, 101)
        s_matrices = sax.models.crossing_ideal(wl=wavelengths)
        # All transmissions should be unity
        transmission = np.abs(s_matrices[("o1", "o3")]) ** 2
        ```

    Note:
        This is an idealized model that assumes:
        - Perfect phase matching between intersecting waveguides
        - No scattering losses at the intersection
        - No reflection at the crossing interfaces
        - Complete isolation between orthogonal paths

        Real crossings typically have finite insertion loss (0.1-1 dB),
        some crosstalk (-20 to -40 dB), and wavelength-dependent performance.
        For more realistic modeling, consider using fabrication-specific
        crossing models with measured parameters.
    """
    one = jnp.ones_like(jnp.asarray(wl))
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o3): one,
            (p.o2, p.o4): one,
        }
    )

grating_coupler

grating_coupler(
    *,
    wl: FloatArrayLike = 1.55,
    wl0: FloatArrayLike = 1.55,
    loss: FloatArrayLike = 0.0,
    reflection: FloatArrayLike = 0.0,
    reflection_fiber: FloatArrayLike = 0.0,
    bandwidth: FloatArrayLike = 0.04,
) -> SDict

Grating coupler model for fiber-chip coupling.

This function models a grating coupler used to couple light between an optical fiber and an on-chip waveguide. The model includes wavelength-dependent transmission with a Gaussian spectral response, insertion loss, and reflection effects from both the waveguide and fiber sides.

Equation adapted from the photontorch grating coupler: https://github.com/flaport/photontorch/blob/master/photontorch/components/gratingcouplers.py

The grating coupler provides:

- Vertical coupling between fiber and chip
- Wavelength-selective transmission
- Bidirectional operation
- Reflection handling for both interfaces

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for spectral analysis. Defaults to 1.55 μm. | 1.55 |
| wl0 | FloatArrayLike | Center wavelength in micrometers where peak transmission occurs. This is the design wavelength of the grating. Defaults to 1.55 μm. | 1.55 |
| loss | FloatArrayLike | Insertion loss in dB at the center wavelength. Includes coupling efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB. | 0.0 |
| reflection | FloatArrayLike | Amplitude reflection coefficient on the waveguide side (chip interface), used directly as the S-parameter, so the reflected power is reflection². Represents reflections back into the waveguide from grating discontinuities. Range: 0 to 1. Defaults to 0.0. | 0.0 |
| reflection_fiber | FloatArrayLike | Amplitude reflection coefficient on the fiber side (top interface); the reflected power is reflection_fiber². Represents reflections back toward the fiber from the grating surface. Range: 0 to 1. Defaults to 0.0. | 0.0 |
| bandwidth | FloatArrayLike | Spectral width in micrometers, used as the full width at half maximum of the Gaussian amplitude response; the half-power (3 dB) width of the transmitted power is bandwidth/√2. Typical values: 20-50 nm. Defaults to 40e-3 μm (40 nm). | 0.04 |

Returns:

| Type | Description |
|------|-------------|
| SDict | The grating coupler S-matrix. |

Examples:

Basic grating coupler:

import sax
import numpy as np

s_matrix = sax.models.grating_coupler(
    wl=1.55,
    wl0=1.55,
    loss=3.0,  # 3 dB insertion loss
    bandwidth=0.035,  # 35 nm bandwidth
)
coupling_efficiency = abs(s_matrix[("in0", "out0")]) ** 2
print(f"Coupling efficiency: {coupling_efficiency:.3f}")

Spectral response analysis:

wavelengths = np.linspace(1.5, 1.6, 101)
s_matrices = sax.models.grating_coupler(
    wl=wavelengths, wl0=1.55, loss=4.0, bandwidth=0.040
)
transmission = np.abs(s_matrices[("in0", "out0")]) ** 2
# Plot spectral response

Grating with reflections:

s_matrix = sax.models.grating_coupler(
    wl=1.55,
    loss=3.5,
    reflection=0.05,  # waveguide-side reflection amplitude (power 0.25%)
    reflection_fiber=0.02,  # fiber-side reflection amplitude (power 0.04%)
)
waveguide_reflection = abs(s_matrix[("in0", "in0")]) ** 2
fiber_reflection = abs(s_matrix[("out0", "out0")]) ** 2
Note

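The power transmission follows a Gaussian shape: T(λ) = |S(in0, out0)|² = T₀ * exp(-((λ-λ₀)/σ)²), with T₀ = 10^(-loss/10).

Where σ = bandwidth / (2√(2ln(2))). The Gaussian is applied to the field amplitude, exp(-(λ-λ₀)²/(2σ²)), so bandwidth is the FWHM of the amplitude response and the half-power width of T(λ) is bandwidth/√2.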

This model assumes:

- Gaussian spectral response (typical for uniform gratings)
- Wavelength-independent loss coefficient
- Linear polarization (TE or TM)
- Single-mode fiber coupling
- No higher-order diffraction effects

Real grating couplers may exhibit:

- Non-Gaussian spectral shapes
- Polarization dependence
- Temperature sensitivity
- Higher-order grating resonances
- Angular sensitivity to fiber positioning

For more accurate modeling, consider using measured spectral data or finite-element simulation results.

                  out0/o2
                   fiber
               /  /  /  /
              /  /  /  /

            _|-|_|-|_|-|__
    in0/o1                |
             _____________|
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wavelengths = np.linspace(1.5, 1.6, 101)
s = sax.models.grating_coupler(
    wl=wavelengths,
    loss=3.0,
    bandwidth=0.035,
)
plt.figure()
plt.plot(wavelengths, np.abs(s[("o1", "o2")]) ** 2, label="Transmission")
plt.xlabel("Wavelength (μm)")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def grating_coupler(
    *,
    wl: sax.FloatArrayLike = 1.55,
    wl0: sax.FloatArrayLike = 1.55,
    loss: sax.FloatArrayLike = 0.0,
    reflection: sax.FloatArrayLike = 0.0,
    reflection_fiber: sax.FloatArrayLike = 0.0,
    bandwidth: sax.FloatArrayLike = 40e-3,
) -> sax.SDict:
    """Grating coupler model for fiber-chip coupling.

    This function models a grating coupler used to couple light between an
    optical fiber and an on-chip waveguide. The model includes wavelength-
    dependent transmission with a Gaussian spectral response, insertion loss,
    and reflection effects from both the waveguide and fiber sides.

    Equation adapted from the photontorch grating coupler:
    https://github.com/flaport/photontorch/blob/master/photontorch/components/gratingcouplers.py

    The grating coupler provides:
    - Vertical coupling between fiber and chip
    - Wavelength-selective transmission
    - Bidirectional operation
    - Reflection handling for both interfaces

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            spectral analysis. Defaults to 1.55 μm.
        wl0: Center wavelength in micrometers where peak transmission occurs.
            This is the design wavelength of the grating. Defaults to 1.55 μm.
        loss: Insertion loss in dB at the center wavelength. Includes coupling
            efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB.
        reflection: Amplitude reflection coefficient on the waveguide side (chip
            interface), used directly as the S-parameter, so the reflected power
            is reflection**2. Represents reflections back into the waveguide from
            grating discontinuities. Range: 0 to 1. Defaults to 0.0.
        reflection_fiber: Amplitude reflection coefficient on the fiber side (top
            interface); the reflected power is reflection_fiber**2. Represents
            reflections back toward the fiber from the grating surface.
            Range: 0 to 1. Defaults to 0.0.
        bandwidth: Spectral width in micrometers, used as the full width at half
            maximum of the Gaussian amplitude response; the half-power (3 dB)
            width of the transmitted power is bandwidth/√2. Typical values:
            20-50 nm. Defaults to 40e-3 μm (40 nm).

    Returns:
        The grating coupler s-matrix

    Examples:
        Basic grating coupler:

        ```python
        import sax
        import numpy as np

        s_matrix = sax.models.grating_coupler(
            wl=1.55,
            wl0=1.55,
            loss=3.0,  # 3 dB insertion loss
            bandwidth=0.035,  # 35 nm bandwidth
        )
        coupling_efficiency = abs(s_matrix[("in0", "out0")]) ** 2
        print(f"Coupling efficiency: {coupling_efficiency:.3f}")
        ```

        Spectral response analysis:

        ```python
        wavelengths = np.linspace(1.5, 1.6, 101)
        s_matrices = sax.models.grating_coupler(
            wl=wavelengths, wl0=1.55, loss=4.0, bandwidth=0.040
        )
        transmission = np.abs(s_matrices[("in0", "out0")]) ** 2
        # Plot spectral response
        ```

        Grating with reflections:

        ```python
        s_matrix = sax.models.grating_coupler(
            wl=1.55,
            loss=3.5,
            reflection=0.05,  # waveguide-side reflection amplitude (power 0.25%)
            reflection_fiber=0.02,  # fiber-side reflection amplitude (power 0.04%)
        )
        waveguide_reflection = abs(s_matrix[("in0", "in0")]) ** 2
        fiber_reflection = abs(s_matrix[("out0", "out0")]) ** 2
        ```

    Note:
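        The power transmission follows a Gaussian shape:
        T(λ) = |S(in0, out0)|² = T₀ * exp(-((λ-λ₀)/σ)²), with T₀ = 10^(-loss/10)

        Where σ = bandwidth / (2*√(2*ln(2))). The Gaussian is applied to the
        field amplitude, exp(-(λ-λ₀)²/(2σ²)), so bandwidth is the FWHM of the
        amplitude response and the half-power width of T(λ) is bandwidth/√2.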

        This model assumes:
        - Gaussian spectral response (typical for uniform gratings)
        - Wavelength-independent loss coefficient
        - Linear polarization (TE or TM)
        - Single-mode fiber coupling
        - No higher-order diffraction effects

        Real grating couplers may exhibit:
        - Non-Gaussian spectral shapes
        - Polarization dependence
        - Temperature sensitivity
        - Higher-order grating resonances
        - Angular sensitivity to fiber positioning

        For more accurate modeling, consider using measured spectral data
        or finite-element simulation results.

    ```
                      out0/o2
                       fiber
                   /  /  /  /
                  /  /  /  /

                _|-|_|-|_|-|__
        in0/o1                |
                 _____________|
    ```

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wavelengths = np.linspace(1.5, 1.6, 101)
    s = sax.models.grating_coupler(
        wl=wavelengths,
        loss=3.0,
        bandwidth=0.035,
    )
    plt.figure()
    plt.plot(wavelengths, np.abs(s[("o1", "o2")]) ** 2, label="Transmission")
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("Power")
    plt.legend()
    ```
    """
    one = jnp.ones_like(wl)
    reflection = jnp.asarray(reflection) * one
    reflection_fiber = jnp.asarray(reflection_fiber) * one
    amplitude = jnp.asarray(10 ** (-loss / 20))
    sigma = jnp.asarray(bandwidth / (2 * jnp.sqrt(2 * jnp.log(2))))
    transmission = jnp.asarray(amplitude * jnp.exp(-((wl - wl0) ** 2) / (2 * sigma**2)))
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.in0): reflection,
            (p.in0, p.out0): transmission,
            (p.out0, p.in0): transmission,
            (p.out0, p.out0): reflection_fiber,
        }
    )
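The relation between bandwidth and the half-power width can be verified from the expressions in the source above. The sketch below restates the transmission amplitude in NumPy (grating_amplitude is a hypothetical helper, not a SAX function) and evaluates the power at the detuning bandwidth/(2√2), where it should fall to half its peak.

```python
import numpy as np


def grating_amplitude(wl, wl0=1.55, loss=0.0, bandwidth=0.04):
    """Hypothetical restatement of the transmission amplitude in grating_coupler."""
    amplitude = 10 ** (-loss / 20)                       # peak field amplitude
    sigma = bandwidth / (2 * np.sqrt(2 * np.log(2)))     # FWHM -> Gaussian sigma
    return amplitude * np.exp(-((wl - wl0) ** 2) / (2 * sigma**2))


bandwidth = 0.04
t_center = grating_amplitude(1.55, loss=3.0, bandwidth=bandwidth)
# Detuning at which the transmitted power drops to half its peak value
half_power_offset = bandwidth / (2 * np.sqrt(2))
t_edge = grating_amplitude(1.55 + half_power_offset, loss=3.0, bandwidth=bandwidth)
print(f"peak power: {t_center**2:.3f}")
print(f"power ratio at +{half_power_offset * 1e3:.1f} nm: {(t_edge / t_center) ** 2:.3f}")
```

At a detuning of bandwidth/2 it is the amplitude, not the power, that has halved, which is why the 3 dB power width is narrower than bandwidth by √2.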

mmi1x2

mmi1x2(
    wl: FloatArrayLike = 1.55,
    wl0: FloatArrayLike = 1.55,
    fwhm: FloatArrayLike = 0.2,
    loss_dB: FloatArrayLike = 0.3,
) -> SDict

Realistic 1x2 MMI splitter model with dispersion and loss.

This function models a realistic 1x2 multimode interference (MMI) splitter with wavelength-dependent transmission, insertion loss, and finite bandwidth. The model uses a Gaussian spectral response to capture the typical behavior of fabricated MMI devices.

The MMI splitter divides input power equally between the two output ports. Away from the design wavelength, both outputs roll off identically following the Gaussian response, so this model does not capture splitting imbalance between the ports.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. Defaults to 1.55 μm. | 1.55 |
| wl0 | FloatArrayLike | Center wavelength in micrometers where optimal splitting occurs. This is the design wavelength of the MMI device. Defaults to 1.55 μm. | 1.55 |
| fwhm | FloatArrayLike | Full width at half maximum bandwidth in micrometers. Determines the spectral width over which the MMI maintains good performance. Typical values: 0.1-0.3 μm. Defaults to 0.2 μm. | 0.2 |
| loss_dB | FloatArrayLike | Insertion loss in dB at the center wavelength. Includes scattering losses, mode conversion losses, and fabrication imperfections. Defaults to 0.3 dB. | 0.3 |

Returns:

| Type | Description |
|------|-------------|
| SDict | S-matrix dictionary representing the dispersive MMI splitter behavior. |

Examples:

Basic MMI splitter:

import sax
import numpy as np

sax.set_port_naming_strategy("optical")

s_matrix = sax.models.mmi1x2(wl=1.55, fwhm=0.15, loss_dB=0.5)
power_out0 = abs(s_matrix[("o1", "o2")]) ** 2
power_out1 = abs(s_matrix[("o1", "o3")]) ** 2
total_power = power_out0 + power_out1
insertion_loss_dB = -10 * np.log10(total_power)

Spectral analysis:

wavelengths = np.linspace(1.4, 1.7, 301)
s_matrices = sax.models.mmi1x2(wl=wavelengths, wl0=1.55, fwhm=0.2, loss_dB=0.4)
transmission = np.abs(s_matrices[("o1", "o2")]) ** 2
# Analyze spectral response and bandwidth

Narrower-bandwidth device:

# Model an MMI with a 100 nm (0.1 μm) FWHM
target_bandwidth = 0.1  # 100 nm
s_matrix = sax.models.mmi1x2(wl=1.55, fwhm=target_bandwidth, loss_dB=0.3)
Note

The spectral response follows a Gaussian profile in the frequency domain: T(λ) = T₀ * exp(-((f-f₀)/σ_f)²)

Where f ∝ 1/λ is the optical frequency and σ_f is related to the FWHM.

This model assumes:

- Equal splitting at the center wavelength
- Gaussian spectral response (typical for well-designed MMIs)
- Wavelength-independent loss coefficient
- No reflection at input/output interfaces

Real MMI devices may exhibit:

- Wavelength-dependent splitting ratio imbalance
- Higher-order spectral features
- Polarization dependence
- Temperature sensitivity
- Phase errors between outputs

           length_mmi
            <------>
            ________
           |        |
           |         \__
           |          __  o2
        __/          /_ _ _ _
     o1 __          | _ _ _ _| gap_mmi
          \          \__
           |          __  o3
           |         /
           |________|

         <->
    length_taper
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wavelengths = np.linspace(1.5, 1.6, 101)
s = sax.models.mmi1x2(wl=wavelengths, fwhm=0.15, loss_dB=0.5)
transmission_o2 = np.abs(s[("o1", "o2")]) ** 2
transmission_o3 = np.abs(s[("o1", "o3")]) ** 2
plt.plot(wavelengths, transmission_o2, label="Output 1")
plt.plot(wavelengths, transmission_o3, label="Output 2")
plt.xlabel("Wavelength (μm)")
plt.ylabel("Transmission")
plt.legend()
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi1x2(
    wl: sax.FloatArrayLike = 1.55,
    wl0: sax.FloatArrayLike = 1.55,
    fwhm: sax.FloatArrayLike = 0.2,
    loss_dB: sax.FloatArrayLike = 0.3,
) -> sax.SDict:
    r"""Realistic 1x2 MMI splitter model with dispersion and loss.

    This function models a realistic 1x2 multimode interference (MMI) splitter
    with wavelength-dependent transmission, insertion loss, and finite bandwidth.
    The model uses a Gaussian spectral response to capture the typical behavior
    of fabricated MMI devices.

    The MMI splitter divides input power equally between the two output ports.
    Away from the design wavelength, both outputs roll off identically following
    the Gaussian response, so this model does not capture splitting imbalance.

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Center wavelength in micrometers where optimal splitting occurs.
            This is the design wavelength of the MMI device. Defaults to 1.55 μm.
        fwhm: Full width at half maximum bandwidth in micrometers. Determines
            the spectral width over which the MMI maintains good performance.
            Typical values: 0.1-0.3 μm. Defaults to 0.2 μm.
        loss_dB: Insertion loss in dB at the center wavelength. Includes
            scattering losses, mode conversion losses, and fabrication
            imperfections. Defaults to 0.3 dB.

    Returns:
        S-matrix dictionary representing the dispersive MMI splitter behavior.

    Examples:
        Basic MMI splitter:

        ```python
        import sax
        import numpy as np

        sax.set_port_naming_strategy("optical")

        s_matrix = sax.models.mmi1x2(wl=1.55, fwhm=0.15, loss_dB=0.5)
        power_out0 = abs(s_matrix[("o1", "o2")]) ** 2
        power_out1 = abs(s_matrix[("o1", "o3")]) ** 2
        total_power = power_out0 + power_out1
        insertion_loss_dB = -10 * np.log10(total_power)
        ```

        Spectral analysis:

        ```python
        wavelengths = np.linspace(1.4, 1.7, 301)
        s_matrices = sax.models.mmi1x2(wl=wavelengths, wl0=1.55, fwhm=0.2, loss_dB=0.4)
        transmission = np.abs(s_matrices[("o1", "o2")]) ** 2
        # Analyze spectral response and bandwidth
        ```

        Narrower-bandwidth device:

        ```python
        # Model an MMI with a 100 nm (0.1 μm) FWHM
        target_bandwidth = 0.1  # 100 nm
        s_matrix = sax.models.mmi1x2(wl=1.55, fwhm=target_bandwidth, loss_dB=0.3)
        ```

    Note:
        The spectral response follows a Gaussian profile in the frequency domain:
        T(λ) = T₀ * exp(-((f-f₀)/σ_f)²)

        Where f ∝ 1/λ is the optical frequency and σ_f is related to the FWHM.

        This model assumes:
        - Equal splitting at the center wavelength
        - Gaussian spectral response (typical for well-designed MMIs)
        - Wavelength-independent loss coefficient
        - No reflection at input/output interfaces

        Real MMI devices may exhibit:
        - Wavelength-dependent splitting ratio imbalance
        - Higher-order spectral features
        - Polarization dependence
        - Temperature sensitivity
        - Phase errors between outputs

    ```
               length_mmi
                <------>
                ________
               |        |
               |         \__
               |          __  o2
            __/          /_ _ _ _
         o1 __          | _ _ _ _| gap_mmi
              \          \__
               |          __  o3
               |         /
               |________|

             <->
        length_taper
    ```

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wavelengths = np.linspace(1.5, 1.6, 101)
    s = sax.models.mmi1x2(wl=wavelengths, fwhm=0.15, loss_dB=0.5)
    transmission_o2 = np.abs(s[("o1", "o2")]) ** 2
    transmission_o3 = np.abs(s[("o1", "o3")]) ** 2
    plt.plot(wavelengths, transmission_o2, label="Output 1")
    plt.plot(wavelengths, transmission_o3, label="Output 2")
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("Transmission")
    plt.legend()
    ```

    """
    thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB) / 2**0.5

    p = sax.PortNamer(1, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o2): thru,
            (p.o1, p.o3): thru,
        }
    )
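The Basic example recovers the insertion loss from the two output powers. The arithmetic at the center wavelength can be sketched as follows, assuming the helper _mmi_amp (not shown in this section) returns a peak amplitude of 10^(-loss_dB/20) at wl0; mmi1x2_center_powers is a hypothetical helper, and only the division by √2 is taken from the source above.

```python
import math


def mmi1x2_center_powers(loss_dB):
    """Hypothetical helper: per-port power at wl0.

    ASSUMES the peak field amplitude is 10**(-loss_dB / 20).
    """
    amp = 10 ** (-loss_dB / 20)   # assumed peak amplitude of the Gaussian response
    thru = amp / math.sqrt(2)     # equal split between o2 and o3, as in the source
    return thru**2, thru**2


p2, p3 = mmi1x2_center_powers(0.5)
insertion_loss_dB = -10 * math.log10(p2 + p3)
print(f"per-port power: {p2:.4f}, recovered insertion loss: {insertion_loss_dB:.2f} dB")
```

Under that assumption, summing the two port powers before taking the logarithm returns exactly loss_dB, while each individual port sits a further ~3.01 dB below the input.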

mmi1x2_ideal

mmi1x2_ideal(*, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 1x2 multimode interference (MMI) splitter model.

This function models an ideal 1x2 MMI splitter that divides input power between two output ports. The model is implemented as an ideal splitter and represents the theoretical limit of MMI splitter performance.

MMI devices use multimode interference in a wide waveguide to achieve power splitting through self-imaging effects. This ideal model assumes perfect operation without considering the detailed physics of multimode propagation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| coupling | FloatArrayLike | Power coupling ratio between 0 and 1. Determines the fraction of input power that couples to the second output port (out1). coupling=0.0: all power to out0; coupling=0.5: equal power splitting (3 dB splitter); coupling=1.0: all power to out1. Defaults to 0.5 for equal splitting. | 0.5 |

Returns:

| Type | Description |
|------|-------------|
| SDict | S-matrix dictionary representing the ideal MMI splitter behavior. |

Examples:

Equal power MMI splitter:

import sax

s_matrix = sax.models.mmi1x2_ideal(coupling=0.5)
power_out0 = abs(s_matrix[("in0", "out0")]) ** 2
power_out1 = abs(s_matrix[("in0", "out1")]) ** 2
print(f"Output powers: {power_out0:.3f}, {power_out1:.3f}")

Asymmetric MMI splitter:

s_matrix = sax.models.mmi1x2_ideal(coupling=0.25)  # 25% to out1
power_out0 = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.75
power_out1 = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.25
Note

This is an idealized model that assumes:

- Perfect power conservation
- No wavelength dependence
- No reflection or insertion loss
- Ideal multimode interference

Real MMI devices exhibit wavelength-dependent behavior, finite bandwidth, and may have imbalance between output ports. For more realistic modeling, use the dispersive mmi1x2() function.

Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi1x2_ideal(*, coupling: sax.FloatArrayLike = 0.5) -> sax.SDict:
    """Ideal 1x2 multimode interference (MMI) splitter model.

    This function models an ideal 1x2 MMI splitter that divides input power
    between two output ports. The model is implemented as an ideal splitter
    and represents the theoretical limit of MMI splitter performance.

    MMI devices use multimode interference in a wide waveguide to achieve
    power splitting through self-imaging effects. This ideal model assumes
    perfect operation without considering the detailed physics of multimode
    propagation.

    Args:
        coupling: Power coupling ratio between 0 and 1. Determines the fraction
            of input power that couples to the second output port (out1).
            - coupling=0.0: All power to out0
            - coupling=0.5: Equal power splitting (3dB splitter)
            - coupling=1.0: All power to out1
            Defaults to 0.5 for equal splitting.

    Returns:
        S-matrix dictionary representing the ideal MMI splitter behavior.

    Examples:
        Equal power MMI splitter:

        ```python
        import sax

        s_matrix = sax.models.mmi1x2_ideal(coupling=0.5)
        power_out0 = abs(s_matrix[("in0", "out0")]) ** 2
        power_out1 = abs(s_matrix[("in0", "out1")]) ** 2
        print(f"Output powers: {power_out0:.3f}, {power_out1:.3f}")
        ```

        Asymmetric MMI splitter:

        ```python
        s_matrix = sax.models.mmi1x2_ideal(coupling=0.25)  # 25% to out1
        power_out0 = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.75
        power_out1 = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.25
        ```

    Note:
        This is an idealized model that assumes:
        - Perfect power conservation
        - No wavelength dependence
        - No reflection or insertion loss
        - Ideal multimode interference

        Real MMI devices exhibit wavelength-dependent behavior, finite bandwidth,
        and may have imbalance between output ports. For more realistic modeling,
        use the dispersive mmi1x2() function.
    """
    return splitter_ideal(coupling=coupling)

mmi2x2

mmi2x2(
    wl: FloatArrayLike = 1.55,
    wl0: FloatArrayLike = 1.55,
    fwhm: FloatArrayLike = 0.2,
    loss_dB: FloatArrayLike = 0.3,
    shift: FloatArrayLike = 0.005,
    loss_dB_cross: FloatArrayLike | None = None,
    loss_dB_thru: FloatArrayLike | None = None,
    splitting_ratio_cross: FloatArrayLike = 0.5,
    splitting_ratio_thru: FloatArrayLike = 0.5,
) -> SDict

Realistic 2x2 MMI coupler model with dispersion and asymmetry.

This function models a realistic 2x2 multimode interference (MMI) coupler with wavelength-dependent behavior, insertion loss, and the ability to model fabrication-induced asymmetries between bar and cross ports.

The MMI coupler can function as a directional coupler, 90-degree hybrid, or beam splitter depending on the design parameters. This model includes separate control over bar and cross port characteristics.

Parameters:

Name Type Description Default
wl FloatArrayLike

Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. Defaults to 1.55 μm.

1.55
wl0 FloatArrayLike

Center wavelength in micrometers for optimal coupling performance. This is the design wavelength of the MMI device. Defaults to 1.55 μm.

1.55
fwhm FloatArrayLike

Full width at half maximum bandwidth in micrometers. Determines the spectral range over which the coupler performs well. Defaults to 0.2 μm.

0.2
loss_dB FloatArrayLike

Base insertion loss in dB at the center wavelength. Used for both ports unless overridden. Defaults to 0.3 dB.

0.3
shift FloatArrayLike

Wavelength shift in micrometers applied to both bar and cross responses, which are centered at wl0 + shift. Models fabrication variations. Defaults to 0.005 μm.

0.005
loss_dB_cross FloatArrayLike | None

Optional separate insertion loss in dB for cross ports. If None, uses loss_dB. Allows modeling of asymmetric loss.

None
loss_dB_thru FloatArrayLike | None

Optional separate insertion loss in dB for bar (through) ports. If None, uses loss_dB. Allows modeling of asymmetric loss.

None
splitting_ratio_cross FloatArrayLike

Power splitting ratio for cross ports (0 to 1). Allows modeling of imbalanced coupling. Defaults to 0.5.

0.5
splitting_ratio_thru FloatArrayLike

Power splitting ratio for bar ports (0 to 1). Allows modeling of imbalanced transmission. Defaults to 0.5.

0.5
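The wavelength, loss, and splitting-ratio parameters combine into a per-port power envelope. The sketch below is a rough NumPy approximation, not SAX's implementation: it assumes a Gaussian power response with the given fwhm (the Note only states that the spectral responses are Gaussian; the internal _mmi_amp helper is not shown here), centered at wl0 + shift, with a peak power of splitting_ratio * 10**(-loss_dB / 10). mmi_power_envelope is a hypothetical name.

```python
import numpy as np


def mmi_power_envelope(wl, wl0=1.55, fwhm=0.2, loss_dB=0.3, shift=0.005,
                       splitting_ratio=0.5):
    """Hypothetical Gaussian power envelope for one MMI output port.

    Assumptions (not taken from SAX's _mmi_amp): Gaussian in power, full
    width at half maximum `fwhm`, centered at wl0 + shift.
    """
    wl = np.asarray(wl, dtype=float)
    center = wl0 + shift
    # exp(-4 ln2 x^2 / fwhm^2) falls to exactly 0.5 at x = fwhm / 2
    gaussian = np.exp(-4.0 * np.log(2.0) * ((wl - center) / fwhm) ** 2)
    # Convert the dB insertion loss to a linear power factor
    peak = splitting_ratio * 10.0 ** (-loss_dB / 10.0)
    return peak * gaussian


wavelengths = np.linspace(1.45, 1.65, 201)
bar_power = mmi_power_envelope(wavelengths, loss_dB=0.3, splitting_ratio=0.5)
```

With the default shift of 0.005 μm, the peak sits at 1.555 μm, so a simulation at exactly wl0 = 1.55 μm is slightly off-center.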

Returns:

Type Description
SDict

S-matrix dictionary representing the realistic MMI coupler behavior.

Examples:

Symmetric 3dB MMI coupler:

import sax
import numpy as np

s_matrix = sax.models.mmi2x2(
    wl=1.55,
    fwhm=0.15,
    loss_dB=0.4,
    splitting_ratio_cross=0.5,
    splitting_ratio_thru=0.5,
)
bar_power = abs(s_matrix[("o1", "o3")]) ** 2
cross_power = abs(s_matrix[("o1", "o4")]) ** 2

Asymmetric MMI with different bar/cross losses:

s_matrix = sax.models.mmi2x2(
    wl=1.55,
    loss_dB_thru=0.3,  # Lower bar loss
    loss_dB_cross=0.6,  # Higher cross loss
    splitting_ratio_cross=0.4,  # Imbalanced coupling
    splitting_ratio_thru=0.6,
)

Wavelength-shifted MMI (fabrication variation):

s_matrix = sax.models.mmi2x2(
    wl=1.55,
    wl0=1.55,
    shift=0.01,  # 10 nm shift from process variation
    fwhm=0.12,
)

Spectral analysis of MMI coupler:

wavelengths = np.linspace(1.45, 1.65, 201)
s_matrices = sax.models.mmi2x2(wl=wavelengths, wl0=1.55, fwhm=0.18, loss_dB=0.5)
bar_transmission = np.abs(s_matrices[("o1", "o3")]) ** 2
cross_transmission = np.abs(s_matrices[("o1", "o4")]) ** 2
Note

The cross-coupled terms include a 90-degree phase shift (1j factor). For a balanced, lossless coupler this keeps the S-matrix unitary, and it represents the physical phase relationship between the outputs of an MMI device.

This model includes:

- Separate bar and cross responses sharing one spectral shape (same center wavelength and fwhm)
- Asymmetric loss modeling
- Wavelength shift for process variations
- Independent splitting ratio control

The model assumes:

- Gaussian spectral responses
- Linear amplitude relationships
- Reciprocal device behavior
- No higher-order spectral features

Real MMI devices may exhibit:

- Non-Gaussian spectral shapes
- Polarization dependence
- Temperature sensitivity
- Multimode interference patterns
- Phase imbalance between outputs
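To see why the cross terms carry the 1j factor, the sketch below builds the 2x2 input-to-output transfer matrix of a balanced, lossless coupler (loss, shift, and dispersion ignored) with and without that factor and tests unitarity with NumPy. It illustrates the phase convention only and does not call sax.models.mmi2x2.

```python
import numpy as np

t = k = 1.0 / np.sqrt(2.0)  # balanced, lossless bar/cross amplitudes

# Transfer matrix with the 90-degree (1j) factor on the cross terms
with_phase = np.array([[t, 1j * k],
                       [1j * k, t]])
# Same magnitudes, but all amplitudes real and positive
without_phase = np.array([[t, k],
                          [k, t]], dtype=complex)


def is_unitary(m):
    """True if m^H m equals the identity (power is conserved)."""
    return np.allclose(m.conj().T @ m, np.eye(m.shape[0]))


print(is_unitary(with_phase))  # True
print(is_unitary(without_phase))  # False
```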

           length_mmi
            <------>
            ________
           |        |
        __/          \__
    o2  __            __  o3
          \          /_ _ _ _
          |         | _ _ _ _| gap_output_tapers
        __/          \__
    o1  __            __  o4
          \          /
           |________|
         | |
         <->
    length_taper
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wavelengths = np.linspace(1.5, 1.6, 101)
s = sax.models.mmi2x2(wl=wavelengths, fwhm=0.15, loss_dB=0.5)
bar_transmission = np.abs(s[("o1", "o3")]) ** 2
cross_transmission = np.abs(s[("o1", "o4")]) ** 2
plt.plot(wavelengths, bar_transmission, label="Bar")
plt.plot(wavelengths, cross_transmission, label="Cross")
plt.xlabel("Wavelength (μm)")
plt.ylabel("Transmission")
plt.legend()
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi2x2(
    wl: sax.FloatArrayLike = 1.55,
    wl0: sax.FloatArrayLike = 1.55,
    fwhm: sax.FloatArrayLike = 0.2,
    loss_dB: sax.FloatArrayLike = 0.3,
    shift: sax.FloatArrayLike = 0.005,
    loss_dB_cross: sax.FloatArrayLike | None = None,
    loss_dB_thru: sax.FloatArrayLike | None = None,
    splitting_ratio_cross: sax.FloatArrayLike = 0.5,
    splitting_ratio_thru: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Realistic 2x2 MMI coupler model with dispersion and asymmetry.

    This function models a realistic 2x2 multimode interference (MMI) coupler
    with wavelength-dependent behavior, insertion loss, and the ability to
    model fabrication-induced asymmetries between bar and cross ports.

    The MMI coupler can function as a directional coupler, 90-degree hybrid,
    or beam splitter depending on the design parameters. This model includes
    separate control over bar and cross port characteristics.

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Center wavelength in micrometers for optimal coupling performance.
            This is the design wavelength of the MMI device. Defaults to 1.55 μm.
        fwhm: Full width at half maximum bandwidth in micrometers. Determines
            the spectral range over which the coupler performs well. Defaults
            to 0.2 μm.
        loss_dB: Base insertion loss in dB at the center wavelength. Used for
            both ports unless overridden. Defaults to 0.3 dB.
        shift: Wavelength shift in micrometers applied to both bar and cross
            responses, which are centered at wl0 + shift. Models fabrication
            variations. Defaults to 0.005 μm.
        loss_dB_cross: Optional separate insertion loss in dB for cross ports.
            If None, uses loss_dB. Allows modeling of asymmetric loss.
        loss_dB_thru: Optional separate insertion loss in dB for bar (through)
            ports. If None, uses loss_dB. Allows modeling of asymmetric loss.
        splitting_ratio_cross: Power splitting ratio for cross ports (0 to 1).
            Allows modeling of imbalanced coupling. Defaults to 0.5.
        splitting_ratio_thru: Power splitting ratio for bar ports (0 to 1).
            Allows modeling of imbalanced transmission. Defaults to 0.5.

    Returns:
        S-matrix dictionary representing the realistic MMI coupler behavior.

    Examples:
        Symmetric 3dB MMI coupler:

        ```python
        import sax
        import numpy as np

        s_matrix = sax.models.mmi2x2(
            wl=1.55,
            fwhm=0.15,
            loss_dB=0.4,
            splitting_ratio_cross=0.5,
            splitting_ratio_thru=0.5,
        )
        bar_power = abs(s_matrix[("o1", "o3")]) ** 2
        cross_power = abs(s_matrix[("o1", "o4")]) ** 2
        ```

        Asymmetric MMI with different bar/cross losses:

        ```python
        s_matrix = sax.models.mmi2x2(
            wl=1.55,
            loss_dB_thru=0.3,  # Lower bar loss
            loss_dB_cross=0.6,  # Higher cross loss
            splitting_ratio_cross=0.4,  # Imbalanced coupling
            splitting_ratio_thru=0.6,
        )
        ```

        Wavelength-shifted MMI (fabrication variation):

        ```python
        s_matrix = sax.models.mmi2x2(
            wl=1.55,
            wl0=1.55,
            shift=0.01,  # 10 nm shift from process variation
            fwhm=0.12,
        )
        ```

        Spectral analysis of MMI coupler:

        ```python
        wavelengths = np.linspace(1.45, 1.65, 201)
        s_matrices = sax.models.mmi2x2(wl=wavelengths, wl0=1.55, fwhm=0.18, loss_dB=0.5)
        bar_transmission = np.abs(s_matrices[("o1", "o3")]) ** 2
        cross_transmission = np.abs(s_matrices[("o1", "o4")]) ** 2
        ```

    Note:
        The cross-coupled terms include a 90-degree phase shift (1j factor).
        For a balanced, lossless coupler this keeps the S-matrix unitary, and
        it represents the physical phase relationship between MMI outputs.

        This model includes:
        - Separate bar and cross responses sharing one spectral shape (same
          center wavelength and fwhm)
        - Asymmetric loss modeling
        - Wavelength shift for process variations
        - Independent splitting ratio control

        The model assumes:
        - Gaussian spectral responses
        - Linear amplitude relationships
        - Reciprocal device behavior
        - No higher-order spectral features

        Real MMI devices may exhibit:
        - Non-Gaussian spectral shapes
        - Polarization dependence
        - Temperature sensitivity
        - Multimode interference patterns
        - Phase imbalance between outputs

    ```
               length_mmi
                <------>
                ________
               |        |
            __/          \__
        o2  __            __  o3
              \          /_ _ _ _
              |         | _ _ _ _| gap_output_tapers
            __/          \__
        o1  __            __  o4
              \          /
               |________|
             | |
             <->
        length_taper
    ```

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wavelengths = np.linspace(1.5, 1.6, 101)
    s = sax.models.mmi2x2(wl=wavelengths, fwhm=0.15, loss_dB=0.5)
    bar_transmission = np.abs(s[("o1", "o3")]) ** 2
    cross_transmission = np.abs(s[("o1", "o4")]) ** 2
    plt.plot(wavelengths, bar_transmission, label="Bar")
    plt.plot(wavelengths, cross_transmission, label="Cross")
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("Transmission")
    plt.legend()
    ```
    """
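    # Fall back to loss_dB only when no explicit value is given; an explicit
    # 0.0 dB must be kept, and `x or y` would also fail on traced JAX arrays.
    loss_dB_cross = loss_dB if loss_dB_cross is None else loss_dB_cross
    loss_dB_thru = loss_dB if loss_dB_thru is None else loss_dB_thru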

    # Convert splitting ratios from power to amplitude by taking the square root
    amplitude_ratio_thru = splitting_ratio_thru**0.5
    amplitude_ratio_cross = splitting_ratio_cross**0.5

    # _mmi_amp already includes the loss, so we don't need to apply it again
    thru = (
        _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB_thru)
        * amplitude_ratio_thru
    )
    cross = (
        1j
        * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB_cross)
        * amplitude_ratio_cross
    )

    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o3): thru,
            (p.o1, p.o4): cross,
            (p.o2, p.o3): cross,
            (p.o2, p.o4): thru,
        }
    )
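The None defaults for loss_dB_cross and loss_dB_thru need an explicit None test when falling back to loss_dB. A truthiness-based fallback such as value or default also replaces a legitimate 0.0 dB, and under jax.jit, where array arguments are traced, asking for the truth value of a traced array raises an error. The plain-Python sketch below contrasts the two patterns; both helper names are hypothetical.

```python
def fallback_or(value, default):
    # Truthiness-based fallback: 0.0 is falsy, so it is replaced as well
    return value or default


def fallback_is_none(value, default):
    # Only a missing value (None) falls back to the default
    return default if value is None else value


print(fallback_or(0.0, 0.3))  # 0.3 (explicit 0 dB silently overridden)
print(fallback_is_none(0.0, 0.3))  # 0.0
print(fallback_is_none(None, 0.3))  # 0.3
```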

mmi2x2_ideal

mmi2x2_ideal(*, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 2x2 multimode interference (MMI) coupler model.

This function models an ideal 2x2 MMI coupler that can function as either a directional coupler or a 90-degree hybrid depending on the design. The model is implemented as an ideal coupler and represents the theoretical limit of MMI coupler performance.

MMI couplers use multimode interference effects in a wide waveguide to achieve controllable coupling between input and output ports. This ideal model assumes perfect operation without detailed multimode analysis.

Parameters:

Name Type Description Default
coupling FloatArrayLike

Power coupling coefficient between 0 and 1. Determines the fraction of power that couples between the two arms. - coupling=0.0: No coupling (straight transmission) - coupling=0.5: 3dB coupler (equal bar/cross transmission) - coupling=1.0: Complete cross-coupling Defaults to 0.5.

0.5

Returns:

Type Description
SDict

S-matrix dictionary representing the ideal MMI coupler behavior.

Examples:

3dB MMI coupler:

import sax

s_matrix = sax.models.mmi2x2_ideal(coupling=0.5)
bar_power = abs(s_matrix[("in0", "out0")]) ** 2  # Bar transmission
cross_power = abs(s_matrix[("in0", "out1")]) ** 2  # Cross transmission
print(f"Bar: {bar_power:.3f}, Cross: {cross_power:.3f}")

Variable coupling MMI:

s_matrix = sax.models.mmi2x2_ideal(coupling=0.8)  # 80% coupling
bar_power = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.2
cross_power = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.8
Note

This is an idealized model that assumes:

- Perfect power conservation
- Ideal 90-degree phase relationships
- No wavelength dependence
- No reflection or insertion loss

Real MMI couplers have wavelength-dependent coupling, finite bandwidth, and may deviate from ideal phase relationships. For more realistic modeling, use the dispersive mmi2x2() function.
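The ideal 90-degree phase relationship has a direct, easily checked consequence: two 3dB couplers in series (with no phase difference between the connecting arms) send all power to the cross port, as in a balanced Mach-Zehnder interferometer. The NumPy sketch below assumes a bar amplitude of sqrt(1 - coupling) and a cross amplitude of 1j * sqrt(coupling), matching the example powers above; coupler_matrix is a hypothetical helper.

```python
import numpy as np


def coupler_matrix(coupling):
    """Hypothetical 2x2 transfer matrix of an ideal coupler.

    Assumed convention: bar amplitude sqrt(1 - coupling),
    cross amplitude 1j * sqrt(coupling).
    """
    tau = np.sqrt(1.0 - coupling)
    kappa = np.sqrt(coupling)
    return np.array([[tau, 1j * kappa],
                     [1j * kappa, tau]])


# Two ideal 3dB couplers back to back (zero-length arms in between)
total = coupler_matrix(0.5) @ coupler_matrix(0.5)
# Bar terms cancel (tau^2 - kappa^2 = 0), so all power ends up crossed
print(np.round(np.abs(total) ** 2, 6))
```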

Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi2x2_ideal(*, coupling: sax.FloatArrayLike = 0.5) -> sax.SDict:
    """Ideal 2x2 multimode interference (MMI) coupler model.

    This function models an ideal 2x2 MMI coupler that can function as either
    a directional coupler or a 90-degree hybrid depending on the design. The
    model is implemented as an ideal coupler and represents the theoretical
    limit of MMI coupler performance.

    MMI couplers use multimode interference effects in a wide waveguide to
    achieve controllable coupling between input and output ports. This ideal
    model assumes perfect operation without detailed multimode analysis.

    Args:
        coupling: Power coupling coefficient between 0 and 1. Determines the
            fraction of power that couples between the two arms.
            - coupling=0.0: No coupling (straight transmission)
            - coupling=0.5: 3dB coupler (equal bar/cross transmission)
            - coupling=1.0: Complete cross-coupling
            Defaults to 0.5.

    Returns:
        S-matrix dictionary representing the ideal MMI coupler behavior.

    Examples:
        3dB MMI coupler:

        ```python
        import sax

        s_matrix = sax.models.mmi2x2_ideal(coupling=0.5)
        bar_power = abs(s_matrix[("in0", "out0")]) ** 2  # Bar transmission
        cross_power = abs(s_matrix[("in0", "out1")]) ** 2  # Cross transmission
        print(f"Bar: {bar_power:.3f}, Cross: {cross_power:.3f}")
        ```

        Variable coupling MMI:

        ```python
        s_matrix = sax.models.mmi2x2_ideal(coupling=0.8)  # 80% coupling
        bar_power = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.2
        cross_power = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.8
        ```

    Note:
        This is an idealized model that assumes:
        - Perfect power conservation
        - Ideal 90-degree phase relationships
        - No wavelength dependence
        - No reflection or insertion loss

        Real MMI couplers have wavelength-dependent coupling, finite bandwidth,
        and may deviate from ideal phase relationships. For more realistic
        modeling, use the dispersive mmi2x2() function.
    """
    return coupler_ideal(coupling=coupling)

model_2port

model_2port(p1: Name, p2: Name) -> SDictModel

Generate a general 2-port model factory with custom port names.

This function creates a factory that generates simple 2-port devices with unity transmission between the specified ports. The resulting model provides ideal transmission with no loss, reflection, or wavelength dependence.

This factory is useful for creating custom 2-port models with specific port naming requirements or as a starting point for more complex models.

Parameters:

Name Type Description Default
p1 Name

Name of the first port (typically input).

required
p2 Name

Name of the second port (typically output).

required

Returns:

Type Description
SDictModel

A 2-port model: a function of wl that returns an SDict with unity transmission between p1 and p2.

Examples:

Create a custom waveguide model:

import sax

# Create a 2-port model with custom port names
waveguide_model = sax.models.model_2port("input", "output")
s_matrix = waveguide_model(wl=1.55)
print(s_matrix[("input", "output")])  # Should be 1.0

Create multiple models with different port names:

taper_model = sax.models.model_2port("in", "out")
delay_line_model = sax.models.model_2port("start", "end")

# Use in circuit simulations
s1 = taper_model(wl=1.55)
s2 = delay_line_model(wl=1.55)
Note

The generated model assumes:

- Perfect unity transmission (no loss)
- No wavelength dependence
- No reflection at either port
- Bidirectional reciprocity

For more complex behavior, consider using the specific component models (straight, attenuator, phase_shifter) or implementing custom models with proper physical parameters.
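The factory pattern described above (an outer function that captures the port names and returns a model of wl) can be sketched without SAX. In this sketch make_2port is hypothetical, and reciprocal is a small stand-in that assumes sax.reciprocal simply mirrors each (p1, p2) entry to (p2, p1); use the real SAX functions in practice.

```python
import numpy as np


def reciprocal(sdict):
    """Stand-in for sax.reciprocal (assumed behavior): copy each (p1, p2)
    entry to (p2, p1) so transmission is the same in both directions."""
    out = dict(sdict)
    for (p1, p2), value in sdict.items():
        out[(p2, p1)] = value
    return out


def make_2port(p1, p2):
    """Hypothetical NumPy re-creation of the model_2port factory pattern."""

    def model(wl=1.55):
        wl = np.asarray(wl, dtype=float)
        # Unity transmission with the same shape as wl (scalar or array)
        return reciprocal({(p1, p2): np.ones_like(wl)})

    return model


waveguide = make_2port("input", "output")
s = waveguide(wl=np.linspace(1.5, 1.6, 5))
print(sorted(s))  # [('input', 'output'), ('output', 'input')]
```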

Source code in src/sax/models/factories.py
@validate_call
def model_2port(p1: sax.Name, p2: sax.Name) -> sax.SDictModel:
    """Generate a general 2-port model factory with custom port names.

    This function creates a factory that generates simple 2-port devices with
    unity transmission between the specified ports. The resulting model provides
    ideal transmission with no loss, reflection, or wavelength dependence.

    This factory is useful for creating custom 2-port models with specific
    port naming requirements or as a starting point for more complex models.

    Args:
        p1: Name of the first port (typically input).
        p2: Name of the second port (typically output).

    Returns:
        A 2-port model: a function of wl that returns an SDict with unity
        transmission between p1 and p2.

    Examples:
        Create a custom waveguide model:

        ```python
        import sax

        # Create a 2-port model with custom port names
        waveguide_model = sax.models.model_2port("input", "output")
        s_matrix = waveguide_model(wl=1.55)
        print(s_matrix[("input", "output")])  # Should be 1.0
        ```

        Create multiple models with different port names:

        ```python
        taper_model = sax.models.model_2port("in", "out")
        delay_line_model = sax.models.model_2port("start", "end")

        # Use in circuit simulations
        s1 = taper_model(wl=1.55)
        s2 = delay_line_model(wl=1.55)
        ```

    Note:
        The generated model assumes:
        - Perfect unity transmission (no loss)
        - No wavelength dependence
        - No reflection at either port
        - Bidirectional reciprocity

        For more complex behavior, consider using the specific component
        models (straight, attenuator, phase_shifter) or implementing
        custom models with proper physical parameters.
    """

    @jax.jit
    @validate_call
    def model_2port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        return sax.reciprocal({(p1, p2): jnp.ones_like(wl)})

    return model_2port

model_3port

model_3port(p1: Name, p2: Name, p3: Name) -> SDictModel

Generate a general 3-port model factory with custom port names.

This function creates a factory that generates 3-port devices implementing a 1-to-2 power splitter with equal splitting ratio. The model provides ideal 3dB splitting from the first port to the other two ports.

This factory is useful for creating custom splitter models with specific port naming or as a foundation for more complex 3-port devices.

Parameters:

Name Type Description Default
p1 Name

Name of the input port.

required
p2 Name

Name of the first output port.

required
p3 Name

Name of the second output port.

required

Returns:

Type Description
SDictModel

A 3-port model: a function of wl that returns an SDict splitting power equally from p1 to p2 and p3.

Examples:

Create a custom Y-branch splitter:

import sax

y_branch_model = sax.models.model_3port("input", "branch1", "branch2")
s_matrix = y_branch_model(wl=1.55)

# Check equal splitting (3dB each)
power_branch1 = abs(s_matrix[("input", "branch1")]) ** 2
power_branch2 = abs(s_matrix[("input", "branch2")]) ** 2
print(f"Powers: {power_branch1:.3f}, {power_branch2:.3f}")  # Both 0.5

Create a tap coupler model:

tap_model = sax.models.model_3port("in", "thru", "drop")
s_matrix = tap_model(wl=1.55)
# Equal 50/50 tapping

Multi-wavelength simulation:

import numpy as np

wavelengths = np.linspace(1.5, 1.6, 101)
splitter_model = sax.models.model_3port("in", "out1", "out2")
s_matrices = splitter_model(wl=wavelengths)
# Wavelength-independent equal splitting
Note

The generated model implements:

- Equal 3dB power splitting (50/50)
- Amplitude coefficients of 1/√2 from the input to each output
- No wavelength dependence
- Perfect reciprocity
- No reflection at any port

For wavelength-dependent or asymmetric splitting, consider using the splitter_ideal model with custom coupling ratios or implementing dispersive splitter models.
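Writing the model's entries as a dense matrix makes the splitting behavior explicit. The NumPy sketch below uses the port names from the multi-wavelength example and sums the output power for unit power launched into each port. The input port conserves power, while light entering either output returns only half its power to the input. That follows from the model having no output-to-output terms, and it is expected: a lossless, reciprocal 3-port cannot be matched at all ports simultaneously.

```python
import numpy as np

ports = ["in", "out1", "out2"]
amp = 1.0 / np.sqrt(2.0)
# Same entries as model_3port, written out in both directions (reciprocity)
sdict = {
    ("in", "out1"): amp, ("out1", "in"): amp,
    ("in", "out2"): amp, ("out2", "in"): amp,
}

# Dense matrix with S[i, j] = transmission from port j to port i
idx = {p: i for i, p in enumerate(ports)}
S = np.zeros((3, 3), dtype=complex)
for (src, dst), value in sdict.items():
    S[idx[dst], idx[src]] = value

# Total output power for unit power launched into each port
launched = np.sum(np.abs(S) ** 2, axis=0)
for port, power in zip(ports, launched):
    print(port, round(float(power), 6))  # in 1.0 / out1 0.5 / out2 0.5
```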

Source code in src/sax/models/factories.py
@validate_call
def model_3port(p1: sax.Name, p2: sax.Name, p3: sax.Name) -> sax.SDictModel:
    """Generate a general 3-port model factory with custom port names.

    This function creates a factory that generates 3-port devices implementing
    a 1-to-2 power splitter with equal splitting ratio. The model provides
    ideal 3dB splitting from the first port to the other two ports.

    This factory is useful for creating custom splitter models with specific
    port naming or as a foundation for more complex 3-port devices.

    Args:
        p1: Name of the input port.
        p2: Name of the first output port.
        p3: Name of the second output port.

    Returns:
        A 3-port model: a function of wl that returns an SDict splitting power
        equally from p1 to p2 and p3.

    Examples:
        Create a custom Y-branch splitter:

        ```python
        import sax

        y_branch_model = sax.models.model_3port("input", "branch1", "branch2")
        s_matrix = y_branch_model(wl=1.55)

        # Check equal splitting (3dB each)
        power_branch1 = abs(s_matrix[("input", "branch1")]) ** 2
        power_branch2 = abs(s_matrix[("input", "branch2")]) ** 2
        print(f"Powers: {power_branch1:.3f}, {power_branch2:.3f}")  # Both 0.5
        ```

        Create a tap coupler model:

        ```python
        tap_model = sax.models.model_3port("in", "thru", "drop")
        s_matrix = tap_model(wl=1.55)
        # Equal 50/50 tapping
        ```

        Multi-wavelength simulation:

        ```python
        import numpy as np

        wavelengths = np.linspace(1.5, 1.6, 101)
        splitter_model = sax.models.model_3port("in", "out1", "out2")
        s_matrices = splitter_model(wl=wavelengths)
        # Wavelength-independent equal splitting
        ```

    Note:
        The generated model implements:
        - Equal 3dB power splitting (50/50)
        - Amplitude coefficients of 1/√2 from the input to each output
        - No wavelength dependence
        - Perfect reciprocity
        - No reflection at any port

        For wavelength-dependent or asymmetric splitting, consider using
        the splitter_ideal model with custom coupling ratios or implementing
        dispersive splitter models.
    """

    @jax.jit
    @validate_call
    def model_3port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        thru = jnp.ones_like(wl) / jnp.sqrt(2)
        return sax.reciprocal(
            {
                (p1, p2): thru,
                (p1, p3): thru,
            }
        )

    return model_3port

model_4port

model_4port(p1: Name, p2: Name, p3: Name, p4: Name) -> SDictModel

Generate a general 4-port model factory with custom port names.

This function creates a factory that generates 4-port devices implementing a 2x2 directional coupler with 3dB coupling. The model provides ideal coupling behavior with proper 90-degree phase relationships.

The coupling configuration is:

- Bar transmission: p1↔p4, p2↔p3 with amplitude 1/√2
- Cross coupling: p1↔p3, p2↔p4 with amplitude j/√2 (90° phase shift)

Parameters:

Name Type Description Default
p1 Name

Name of the first input port.

required
p2 Name

Name of the second input port.

required
p3 Name

Name of the first output port.

required
p4 Name

Name of the second output port.

required

Returns:

Type Description
SDictModel

A 4-port model: a function of wl that returns the SDict of an ideal 3dB coupler between the given ports.

Examples:

Create a custom directional coupler:

import jax.numpy as jnp
import sax

dc_model = sax.models.model_4port("in1", "in2", "out1", "out2")
s_matrix = dc_model(wl=1.55)

# Check 3dB coupling
bar_power = abs(s_matrix[("in1", "out2")]) ** 2
cross_power = abs(s_matrix[("in1", "out1")]) ** 2
print(f"Bar: {bar_power:.3f}, Cross: {cross_power:.3f}")  # Both 0.5

# Check 90-degree phase relationship
bar_phase = jnp.angle(s_matrix[("in1", "out2")])
cross_phase = jnp.angle(s_matrix[("in1", "out1")])
phase_diff = cross_phase - bar_phase
print(f"Phase difference: {phase_diff:.3f} rad")  # Should be π/2

Create a 2x2 MMI model:

mmi_model = sax.models.model_4port("i1", "i2", "o1", "o2")
s_matrix = mmi_model(wl=1.55)

Multi-wavelength coupler analysis:

import numpy as np

wavelengths = np.linspace(1.5, 1.6, 101)
coupler_model = sax.models.model_4port("p1", "p2", "p3", "p4")
s_matrices = coupler_model(wl=wavelengths)
# Wavelength-independent 3dB coupling
Note

The generated model implements:

- 3dB (50/50) power coupling
- 90-degree phase shift for cross-coupled terms (1j factor)
- Perfect power conservation (unitary S-matrix)
- No wavelength dependence
- Complete reciprocity

The S-matrix structure follows the standard directional coupler form:

- Through paths: Real amplitude coefficients
- Cross paths: Imaginary amplitude coefficients (j/√2)

For variable coupling ratios or dispersive behavior, use the coupler_ideal or coupler models with appropriate parameters.
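The unitarity and 90-degree claims above can be verified by expanding the four listed port pairs into a dense, symmetric 4x4 matrix. This NumPy sketch uses generic port names p1 to p4 and mirrors each entry to encode reciprocity; it does not call SAX.

```python
import numpy as np

ports = ["p1", "p2", "p3", "p4"]
thru = 1.0 / np.sqrt(2.0)
cross = 1j * thru
# Entries as listed for model_4port
pairs = {("p1", "p4"): thru, ("p2", "p3"): thru,
         ("p1", "p3"): cross, ("p2", "p4"): cross}

idx = {p: i for i, p in enumerate(ports)}
S = np.zeros((4, 4), dtype=complex)
for (a, b), value in pairs.items():
    S[idx[a], idx[b]] = value
    S[idx[b], idx[a]] = value  # reciprocity: S is symmetric

# Power conservation: S^H S should be the 4x4 identity
print(np.allclose(S.conj().T @ S, np.eye(4)))  # True
# Cross phase minus bar phase for light entering p1: pi/2, about 1.5708
print(np.angle(S[idx["p3"], idx["p1"]]) - np.angle(S[idx["p4"], idx["p1"]]))
```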

Source code in src/sax/models/factories.py
@validate_call
def model_4port(
    p1: sax.Name, p2: sax.Name, p3: sax.Name, p4: sax.Name
) -> sax.SDictModel:
    """Generate a general 4-port model factory with custom port names.

    This function creates a factory that generates 4-port devices implementing
    a 2x2 directional coupler with 3dB coupling. The model provides ideal
    coupling behavior with proper 90-degree phase relationships.

    The coupling configuration is:
    - Bar transmission: p1↔p4, p2↔p3 with amplitude 1/√2
    - Cross coupling: p1↔p3, p2↔p4 with amplitude j/√2 (90° phase shift)

    Args:
        p1: Name of the first input port.
        p2: Name of the second input port.
        p3: Name of the first output port.
        p4: Name of the second output port.

    Returns:
        A 4-port model: a function of wl that returns the SDict of an ideal
        3dB coupler between the given ports.

    Examples:
        Create a custom directional coupler:

        ```python
        import jax.numpy as jnp
        import sax

        dc_model = sax.models.model_4port("in1", "in2", "out1", "out2")
        s_matrix = dc_model(wl=1.55)

        # Check 3dB coupling
        bar_power = abs(s_matrix[("in1", "out2")]) ** 2
        cross_power = abs(s_matrix[("in1", "out1")]) ** 2
        print(f"Bar: {bar_power:.3f}, Cross: {cross_power:.3f}")  # Both 0.5

        # Check 90-degree phase relationship
        bar_phase = jnp.angle(s_matrix[("in1", "out2")])
        cross_phase = jnp.angle(s_matrix[("in1", "out1")])
        phase_diff = cross_phase - bar_phase
        print(f"Phase difference: {phase_diff:.3f} rad")  # Should be π/2
        ```

        Create a 2x2 MMI model:

        ```python
        mmi_model = sax.models.model_4port("i1", "i2", "o1", "o2")
        s_matrix = mmi_model(wl=1.55)
        ```

        Multi-wavelength coupler analysis:

        ```python
        import numpy as np

        wavelengths = np.linspace(1.5, 1.6, 101)
        coupler_model = sax.models.model_4port("p1", "p2", "p3", "p4")
        s_matrices = coupler_model(wl=wavelengths)
        # Wavelength-independent 3dB coupling
        ```

    Note:
        The generated model implements:
        - 3dB (50/50) power coupling
        - 90-degree phase shift for cross-coupled terms (1j factor)
        - Perfect power conservation (unitary S-matrix)
        - No wavelength dependence
        - Complete reciprocity

        The S-matrix structure follows the standard directional coupler form:
        - Through paths: Real amplitude coefficients
        - Cross paths: Imaginary amplitude coefficients (j/√2)

        For variable coupling ratios or dispersive behavior, use the
        coupler_ideal or coupler models with appropriate parameters.
    """

    @jax.jit
    @validate_call
    def model_4port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        thru = jnp.ones_like(wl) / jnp.sqrt(2)
        cross = 1j * thru
        return sax.reciprocal(
            {
                (p1, p4): thru,
                (p2, p3): thru,
                (p1, p3): cross,
                (p2, p4): cross,
            }
        )

    return model_4port

passthru

passthru(num_links: int, *, reciprocal: bool = True) -> SCooModel

Generate a multi-port pass-through device model.

This function creates a model for devices that provide direct one-to-one connections between input and output ports with no coupling between different channels. Each input port connects directly to its corresponding output port, so the only nonzero S-matrix entries are those linking each input to its own output.

Pass-through devices are useful for modeling:

- Fiber patch panels and interconnects
- Optical switches in the "straight-through" state
- Multi-channel delay lines or phase shifters
- Ideal transmission links with no crosstalk

Parameters:

Name Type Description Default
num_links int

Number of independent pass-through links (input-output pairs). This creates a device with num_links inputs and num_links outputs.

required
reciprocal bool

If True, the device exhibits reciprocal behavior where transmission is identical in both directions. This is typical for passive devices like fibers and waveguides. If False, light is transmitted only from the inputs to the outputs. Defaults to True.

True

Returns:

Type Description
SCooModel

A passthru model: a function of wl that returns the S-matrix in SCoo form (Si, Sj, Sx, port_map).

Examples:

Create a 4-channel fiber ribbon cable:

import sax

fiber_ribbon = sax.models.passthru(4, reciprocal=True)
Si, Sj, Sx, port_map = fiber_ribbon(wl=1.55)
# 4 independent channels with unity transmission

Create an 8×8 optical switch (straight-through state):

switch_thru = sax.models.passthru(8, reciprocal=True)
Si, Sj, Sx, port_map = switch_thru(wl=1.55)
# Each input passes straight to corresponding output

Create a unidirectional buffer array:

buffer_array = sax.models.passthru(6, reciprocal=False)
Si, Sj, Sx, port_map = buffer_array(wl=1.55)
# One-way transmission for each channel

Multi-wavelength pass-through analysis:

import numpy as np

wavelengths = np.linspace(1.5, 1.6, 101)
multilink = sax.models.passthru(3, reciprocal=True)
Si, Sj, Sx, port_map = multilink(wl=wavelengths)
# Wavelength-independent transmission for all links
Note

The pass-through model creates a permutation-type S-matrix (unitary when reciprocal=True) where:

- Each input connects only to its corresponding output
- No crosstalk between different channels
- Unity transmission (lossless) for each link
- Perfect isolation between channels

The S-matrix structure is:

S = [0  I]  where I is the identity matrix
    [I* 0]  and I* is I if reciprocal else 0

This model assumes:

- Perfect channel isolation (infinite extinction ratio)
- No wavelength dependence
- No insertion loss
- Ideal matching at all interfaces

For realistic multi-channel devices, consider adding:

- Channel-to-channel crosstalk
- Wavelength-dependent transmission
- Insertion loss and reflection
- Polarization effects

Applications include:

- Modeling ideal optical fibers
- Switch matrices in through state
- Multi-channel delay lines
- Parallel processing networks
- Test and measurement setups
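The block structure in the Note can be written out directly. This NumPy sketch assumes ports ordered as all inputs followed by all outputs and reproduces only the matrix pattern described above, not the SCoo tuple that passthru returns. It also shows that the matrix is a permutation, and therefore unitary, only in the reciprocal case. passthru_dense is a hypothetical helper.

```python
import numpy as np


def passthru_dense(num_links, reciprocal=True):
    """Hypothetical dense sketch of S = [[0, I], [I*, 0]],
    where I* is I if reciprocal else 0."""
    eye = np.eye(num_links)
    zeros = np.zeros_like(eye)
    lower = eye if reciprocal else zeros
    return np.block([[zeros, eye], [lower, zeros]])


def is_unitary(m):
    """True if m^H m equals the identity."""
    return np.allclose(m.conj().T @ m, np.eye(m.shape[0]))


S_rec = passthru_dense(3, reciprocal=True)
S_uni = passthru_dense(3, reciprocal=False)
print(is_unitary(S_rec))  # True: a permutation matrix
print(is_unitary(S_uni))  # False: only one direction transmits
```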

Source code in src/sax/models/factories.py
@validate_call
def passthru(
    num_links: int,
    *,
    reciprocal: bool = True,
) -> sax.SCooModel:
    """Generate a multi-port pass-through device model.

    This function creates a model for devices that provide direct one-to-one
    connections between input and output ports with no coupling between
    different channels. Each input port connects directly to its corresponding
    output port, so the only nonzero S-matrix entries are those linking each
    input to its own output.

    Pass-through devices are useful for modeling:
    - Fiber patch panels and interconnects
    - Optical switches in the "straight-through" state
    - Multi-channel delay lines or phase shifters
    - Ideal transmission links with no crosstalk

    Args:
        num_links: Number of independent pass-through links (input-output pairs).
            This creates a device with num_links inputs and num_links outputs.
        reciprocal: If True, the device exhibits reciprocal behavior where
            transmission is identical in both directions. This is typical for
            passive devices like fibers and waveguides. Defaults to True.

    Returns:
        A jitted model function that takes `wl` and returns the SCoo tuple
        `(Si, Sj, Sx, port_map)`.

    Examples:
        Create a 4-channel fiber ribbon cable:

        ```python
        import sax

        fiber_ribbon = sax.models.passthru(4, reciprocal=True)
        Si, Sj, Sx, port_map = fiber_ribbon(wl=1.55)
        # 4 independent channels with unity transmission
        ```

        Create an 8×8 optical switch (straight-through state):

        ```python
        switch_thru = sax.models.passthru(8, reciprocal=True)
        Si, Sj, Sx, port_map = switch_thru(wl=1.55)
        # Each input passes straight to corresponding output
        ```

        Create a unidirectional buffer array:

        ```python
        buffer_array = sax.models.passthru(6, reciprocal=False)
        Si, Sj, Sx, port_map = buffer_array(wl=1.55)
        # One-way transmission for each channel
        ```

        Multi-wavelength pass-through analysis:

        ```python
        import numpy as np

        wavelengths = np.linspace(1.5, 1.6, 101)
        multilink = sax.models.passthru(3, reciprocal=True)
        Si, Sj, Sx, port_map = multilink(wl=wavelengths)
        # Wavelength-independent transmission for all links
        ```

    Note:
        The pass-through model creates an S-matrix whose input-to-output block
        is the identity (the full matrix is unitary when reciprocal=True), so that:
        - Each input connects only to its corresponding output
        - No crosstalk between different channels
        - Unity transmission (lossless) for each link
        - Perfect isolation between channels

        The S-matrix structure is:
        ```
        S = [0  I]  where I is the identity matrix
            [I* 0]  and I* is I if reciprocal else 0
        ```

        This model assumes:
        - Perfect channel isolation (infinite extinction ratio)
        - No wavelength dependence
        - No insertion loss
        - Ideal matching at all interfaces

        For realistic multi-channel devices, consider adding:
        - Channel-to-channel crosstalk
        - Wavelength-dependent transmission
        - Insertion loss and reflection
        - Polarization effects

        Applications include:
        - Modeling ideal optical fibers
        - Switch matrices in through state
        - Multi-channel delay lines
        - Parallel processing networks
        - Test and measurement setups
    """
    passthru = unitary(
        num_links,
        num_links,
        reciprocal=reciprocal,
        diagonal=True,
    )
    passthru.__name__ = f"passthru_{num_links}_{num_links}"
    passthru.__qualname__ = f"passthru_{num_links}_{num_links}"
    return jax.jit(passthru)

phase_shifter

phase_shifter(
    wl: FloatArrayLike = 1.55,
    neff: FloatArrayLike = 2.34,
    voltage: FloatArrayLike = 0,
    length: FloatArrayLike = 10,
    loss: FloatArrayLike = 0.0,
) -> SDict

Simple voltage-controlled phase shifter model.

This function models a thermo-optic or electro-optic phase shifter that provides voltage-controlled phase modulation. The device introduces a voltage-dependent phase shift while maintaining (nearly) constant amplitude.

The model combines:
- Passive waveguide propagation phase
- Voltage-induced phase shift (linear relationship)
- Optional loss from the active region
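The implementation in the source listing below computes a single complex transmission, t = 10^(-loss·length/20) · exp(i·(2π·neff·length/wl + π·voltage)). The NumPy sketch below re-implements only that expression outside SAX, so two of its less obvious consequences can be checked directly. First, the voltage term does not depend on `length`. Second, `loss` acts as dB per micrometer. The function name `phase_shifter_t` is hypothetical.

```python
import numpy as np


def phase_shifter_t(wl=1.55, neff=2.34, voltage=0.0, length=10.0, loss=0.0):
    """Hypothetical NumPy copy of the phase shifter's transmission formula."""
    phase = 2 * np.pi * neff * length / wl + np.pi * voltage
    amplitude = 10 ** (-loss * length / 20)  # loss is in dB per um
    return amplitude * np.exp(1j * phase)


# The voltage-induced phase is pi rad/V, whatever the length.
for length in (10.0, 100.0, 1000.0):
    ratio = phase_shifter_t(voltage=0.5, length=length) / phase_shifter_t(length=length)
    print(length, np.angle(ratio))  # about 1.5708 (pi/2) for every length

# 0.5e-3 dB/um over 1000 um gives about 0.5 dB of loss.
t = phase_shifter_t(length=1000.0, loss=0.5e-3)
print(-10 * np.log10(abs(t) ** 2))
```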

Parameters:

Name Type Description Default
wl FloatArrayLike

Operating wavelength in micrometers. The phase shift has a 1/λ dependence from the propagation term. Defaults to 1.55 μm.

1.55
neff FloatArrayLike

Effective refractive index of the unperturbed waveguide mode. This determines the baseline propagation phase. Defaults to 2.34.

2.34
voltage FloatArrayLike

Applied voltage in volts. The phase shift is assumed to be linearly proportional to voltage with a coefficient of π rad/V. Positive voltage increases the phase. Defaults to 0 V.

0
length FloatArrayLike

Active length of the phase shifter in micrometers. The propagation phase and the total loss scale with length; the voltage-induced phase (π rad/V) does not. Defaults to 10 μm.

10
loss FloatArrayLike

Propagation loss of the active region in dB per micrometer (e.g., free-carrier absorption in silicon modulators). The total loss is loss × length dB. Defaults to 0.0 dB/μm.

0.0

Returns:

Type Description
SDict

S-matrix dictionary containing the complex-valued transmission coefficient.

Examples:

Basic phase shifter:

import sax
import numpy as np

# No voltage applied
s_matrix = sax.models.phase_shifter(wl=1.55, voltage=0.0, length=100.0)
phase_0V = np.angle(s_matrix[("o1", "o2")])

# 1V applied
s_matrix = sax.models.phase_shifter(wl=1.55, voltage=1.0, length=100.0)
phase_1V = np.angle(s_matrix[("o1", "o2")])
voltage_phase_shift = phase_1V - phase_0V
print(f"Voltage-induced phase shift: {voltage_phase_shift:.3f} rad")

Phase modulation analysis:

voltages = np.linspace(-2, 2, 101)
phases = []
for v in voltages:
    s = sax.models.phase_shifter(voltage=v, length=50.0)
    phases.append(np.angle(s[("o1", "o2")]))
phases = np.array(phases)
# After np.unwrap(phases), the slope is π rad/V, independent of length
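`np.angle` wraps phases into (-π, π], so the raw `phases` array jumps by 2π as the voltage is swept. The self-contained NumPy sketch below runs the same sweep using the model's phase formula directly instead of calling SAX. It shows that unwrapping the phases recovers the expected π rad/V slope.

```python
import numpy as np

wl, neff, length = 1.55, 2.34, 50.0
voltages = np.linspace(-2, 2, 101)
# Transmission phase used by the model: propagation term + pi rad/V.
t = np.exp(1j * (2 * np.pi * neff * length / wl + np.pi * voltages))
phases = np.unwrap(np.angle(t))  # remove the 2*pi jumps
slope = np.polyfit(voltages, phases, 1)[0]  # first coefficient is the slope
print(slope / np.pi)  # about 1.0, i.e. pi rad/V
```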

Lossy phase shifter (e.g., silicon modulator):

s_matrix = sax.models.phase_shifter(
    voltage=1.0,
    length=1000.0,  # 1 mm
    loss=0.5e-3,  # dB/μm, i.e. 0.5 dB total over 1 mm
)
transmission = abs(s_matrix[("o1", "o2")]) ** 2
phase_shift = np.angle(s_matrix[("o1", "o2")])

Note

This simplified model assumes:
- Linear voltage-to-phase relationship (π rad/V)
- Negligible voltage dependence of loss
- No reflection from the active region
- Perfect reciprocity

Real phase shifters may exhibit:
- Nonlinear voltage response
- Voltage-dependent loss
- Bandwidth limitations
- Temperature sensitivity
- Polarization dependence

The model uses a simple π rad/V coefficient which may need calibration for specific devices and operating conditions.
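The model's coefficient is fixed at π rad/V, which amounts to an effective Vπ of 1 V. One way to emulate a device with a measured half-wave voltage is to rescale the drive voltage before passing it in. The helper below is a hypothetical sketch. `model_voltage` and `v_pi` are not SAX names, and the sketch still assumes a linear response.

```python
def model_voltage(applied_voltage: float, v_pi: float) -> float:
    """Hypothetical mapping from a physical drive voltage to the model's `voltage`.

    The model adds pi * voltage radians, so passing applied_voltage / v_pi
    gives pi * applied_voltage / v_pi: a pi shift at applied_voltage == v_pi.
    """
    if v_pi <= 0:
        raise ValueError("v_pi must be positive")
    return applied_voltage / v_pi


# Illustrative values: a device with V_pi = 2 V driven at 1 V should see pi/2.
print(model_voltage(1.0, 2.0))  # 0.5, i.e. a model phase shift of pi/2
```

The result would then be passed as `sax.models.phase_shifter(voltage=model_voltage(1.0, 2.0), ...)`.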

Source code in src/sax/models/straight.py
@jax.jit
@validate_call
def phase_shifter(
    wl: sax.FloatArrayLike = 1.55,
    neff: sax.FloatArrayLike = 2.34,
    voltage: sax.FloatArrayLike = 0,
    length: sax.FloatArrayLike = 10,
    loss: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    """Simple voltage-controlled phase shifter model.

    This function models a thermo-optic or electro-optic phase shifter that
    provides voltage-controlled phase modulation. The device introduces a
    voltage-dependent phase shift while maintaining (nearly) constant amplitude.

    The model combines:
    - Passive waveguide propagation phase
    - Voltage-induced phase shift (linear relationship)
    - Optional loss from the active region

    Args:
        wl: Operating wavelength in micrometers. The phase shift has a 1/λ
            dependence from the propagation term. Defaults to 1.55 μm.
        neff: Effective refractive index of the unperturbed waveguide mode.
            This determines the baseline propagation phase. Defaults to 2.34.
        voltage: Applied voltage in volts. The phase shift is assumed to be
            linearly proportional to voltage with a coefficient of π rad/V.
            Positive voltage increases the phase. Defaults to 0 V.
        length: Active length of the phase shifter in micrometers. The
            propagation phase and the total loss scale with length; the
            voltage-induced phase (π rad/V) does not. Defaults to 10 μm.
        loss: Propagation loss of the active region in dB per micrometer
            (e.g., free-carrier absorption in silicon modulators). The total
            loss is loss * length dB. Defaults to 0.0 dB/μm.

    Returns:
        S-matrix dictionary containing the complex-valued transmission coefficient.

    Examples:
        Basic phase shifter:

        ```python
        import sax
        import numpy as np

        # No voltage applied
        s_matrix = sax.models.phase_shifter(wl=1.55, voltage=0.0, length=100.0)
        phase_0V = np.angle(s_matrix[("o1", "o2")])

        # 1V applied
        s_matrix = sax.models.phase_shifter(wl=1.55, voltage=1.0, length=100.0)
        phase_1V = np.angle(s_matrix[("o1", "o2")])
        voltage_phase_shift = phase_1V - phase_0V
        print(f"Voltage-induced phase shift: {voltage_phase_shift:.3f} rad")
        ```

        Phase modulation analysis:

        ```python
        voltages = np.linspace(-2, 2, 101)
        phases = []
        for v in voltages:
            s = sax.models.phase_shifter(voltage=v, length=50.0)
            phases.append(np.angle(s[("o1", "o2")]))
        phases = np.array(phases)
        # After np.unwrap(phases), the slope is π rad/V, independent of length
        ```

        Lossy phase shifter (e.g., silicon modulator):

        ```python
        s_matrix = sax.models.phase_shifter(
            voltage=1.0,
            length=1000.0,  # 1 mm
            loss=0.5e-3,  # dB/μm, i.e. 0.5 dB total over 1 mm
        )
        transmission = abs(s_matrix[("o1", "o2")]) ** 2
        phase_shift = np.angle(s_matrix[("o1", "o2")])
        ```

    Note:
        This simplified model assumes:
        - Linear voltage-to-phase relationship (π rad/V)
        - Negligible voltage dependence of loss
        - No reflection from the active region
        - Perfect reciprocity

        Real phase shifters may exhibit:
        - Nonlinear voltage response
        - Voltage-dependent loss
        - Bandwidth limitations
        - Temperature sensitivity
        - Polarization dependence

        The model uses a simple π rad/V coefficient which may need calibration
        for specific devices and operating conditions.
    """
    deltaphi = voltage * jnp.pi
    phase = 2 * jnp.pi * neff * length / wl + deltaphi
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    transmission = amplitude * jnp.exp(1j * phase)
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.o1, p.o2): transmission,
        }
    )

splitter_ideal

splitter_ideal(*, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 1x2 power splitter model.

This function models an ideal 1x2 power splitter that divides input optical power between two output ports according to a specified coupling ratio. The model assumes no insertion loss and perfect reciprocity.

The splitter has three ports:
- One input port (in0)
- Two output ports (out0, out1)

Power splitting is determined by the coupling parameter:
- Fraction (1-coupling) goes to out0
- Fraction coupling goes to out1
- Total power is conserved (no loss)
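The amplitudes follow from these power fractions by taking square roots, as in the source listing below: tau = sqrt(1 - coupling) goes to out0 and kappa = sqrt(coupling) goes to out1. The NumPy sketch below checks the resulting split and power conservation for a few ratios. `splitter_amplitudes` is a hypothetical helper name.

```python
import numpy as np


def splitter_amplitudes(coupling):
    """Hypothetical helper: thru (out0) and cross (out1) amplitudes."""
    coupling = np.asarray(coupling, dtype=float)
    tau = np.sqrt(1 - coupling)  # amplitude to out0
    kappa = np.sqrt(coupling)  # amplitude to out1
    return tau, kappa


for c in (0.0, 0.1, 0.5, 1.0):
    tau, kappa = splitter_amplitudes(c)
    print(c, tau**2, kappa**2, tau**2 + kappa**2)  # last value is 1 up to rounding
```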

Parameters:

Name Type Description Default
coupling FloatArrayLike

Power coupling ratio between 0 and 1. Determines the fraction of input power that couples to the second output port (out1):
- coupling=0.0: All power to out0 (no splitting)
- coupling=0.5: Equal power splitting (3 dB splitter)
- coupling=1.0: All power to out1
Defaults to 0.5 for equal splitting.

0.5

Returns:

Type Description
SDict

S-matrix dictionary containing the complex-valued cross/thru coefficients.

Examples:

Equal power splitter (3dB splitter):

import sax

# 50/50 power splitter
s_matrix = sax.models.splitter_ideal(coupling=0.5)
print(f"Power to out0: {abs(s_matrix[('in0', 'out0')]) ** 2}")
print(f"Power to out1: {abs(s_matrix[('in0', 'out1')]) ** 2}")
# Both should be 0.5

Asymmetric splitter:

# 90/10 power splitter
s_matrix = sax.models.splitter_ideal(coupling=0.1)
power_out0 = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.9
power_out1 = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.1

Variable coupling analysis:

import numpy as np

couplings = np.linspace(0, 1, 101)
power_ratios = []
for c in couplings:
    s = sax.models.splitter_ideal(coupling=c)
    ratio = abs(s[("in0", "out1")]) ** 2 / abs(s[("in0", "out0")]) ** 2
    power_ratios.append(ratio)

Note

This is an idealized lossless model that assumes:
- Perfect power conservation (no insertion loss)
- No reflection at any port
- Wavelength-independent behavior
- Perfect reciprocity

Real splitters typically have some insertion loss (0.1-0.5 dB), wavelength-dependent splitting ratios, and may exhibit some reflection. The model uses amplitude coefficients (square root of power ratios) to ensure proper S-matrix properties.
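To approximate this insertion loss without changing the split ratio, both amplitudes can be scaled by the same factor, 10^(-IL/20). The sketch below is hypothetical: `lossy_splitter_amplitudes` and `il_db` are not SAX parameters. It still ignores reflection and wavelength dependence.

```python
import numpy as np


def lossy_splitter_amplitudes(coupling: float, il_db: float = 0.3):
    """Hypothetical: ideal splitter amplitudes scaled by a common insertion loss."""
    scale = 10 ** (-il_db / 20)  # dB -> amplitude factor
    return scale * np.sqrt(1 - coupling), scale * np.sqrt(coupling)


tau, kappa = lossy_splitter_amplitudes(0.5, il_db=0.3)
total_db = -10 * np.log10(tau**2 + kappa**2)
print(round(total_db, 3))  # 0.3
```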

Source code in src/sax/models/splitters.py
@jax.jit
@validate_call
def splitter_ideal(*, coupling: sax.FloatArrayLike = 0.5) -> sax.SDict:
    """Ideal 1x2 power splitter model.

    This function models an ideal 1x2 power splitter that divides input optical
    power between two output ports according to a specified coupling ratio.
    The model assumes no insertion loss and perfect reciprocity.

    The splitter has three ports:
    - One input port (in0)
    - Two output ports (out0, out1)

    Power splitting is determined by the coupling parameter:
    - Fraction (1-coupling) goes to out0
    - Fraction coupling goes to out1
    - Total power is conserved (no loss)

    Args:
        coupling: Power coupling ratio between 0 and 1. Determines the fraction
            of input power that couples to the second output port (out1).
            - coupling=0.0: All power to out0 (no splitting)
            - coupling=0.5: Equal power splitting (3dB splitter)
            - coupling=1.0: All power to out1
            Defaults to 0.5 for equal splitting.

    Returns:
        S-matrix dictionary containing the complex-valued cross/thru coefficients.

    Examples:
        Equal power splitter (3dB splitter):

        ```python
        import sax

        # 50/50 power splitter
        s_matrix = sax.models.splitter_ideal(coupling=0.5)
        print(f"Power to out0: {abs(s_matrix[('in0', 'out0')]) ** 2}")
        print(f"Power to out1: {abs(s_matrix[('in0', 'out1')]) ** 2}")
        # Both should be 0.5
        ```

        Asymmetric splitter:

        ```python
        # 90/10 power splitter
        s_matrix = sax.models.splitter_ideal(coupling=0.1)
        power_out0 = abs(s_matrix[("in0", "out0")]) ** 2  # Should be 0.9
        power_out1 = abs(s_matrix[("in0", "out1")]) ** 2  # Should be 0.1
        ```

        Variable coupling analysis:

        ```python
        import numpy as np

        couplings = np.linspace(0, 1, 101)
        power_ratios = []
        for c in couplings:
            s = sax.models.splitter_ideal(coupling=c)
            ratio = abs(s[("in0", "out1")]) ** 2 / abs(s[("in0", "out0")]) ** 2
            power_ratios.append(ratio)
        ```

    Note:
        This is an idealized lossless model that assumes:
        - Perfect power conservation (no insertion loss)
        - No reflection at any port
        - Wavelength-independent behavior
        - Perfect reciprocity

        Real splitters typically have some insertion loss (0.1-0.5 dB),
        wavelength-dependent splitting ratios, and may exhibit some reflection.
        The model uses amplitude coefficients (square root of power ratios)
        to ensure proper S-matrix properties.
    """
    kappa = jnp.asarray(coupling**0.5)
    tau = jnp.asarray((1 - coupling) ** 0.5)

    p = sax.PortNamer(num_inputs=1, num_outputs=2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): kappa,
        },
    )

unitary

unitary(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> SCooModel

Generate a unitary N×M optical device model.

This function creates a model for a unitary optical device that conserves power while providing controllable coupling between input and output ports. The model generates a physically realizable S-matrix through singular value decomposition and proper normalization.

Unitary devices are fundamental building blocks in photonics, including star couplers, multimode interference devices, and arbitrary unitary transformations for quantum photonic circuits.

Parameters:

Name Type Description Default
num_inputs int

Number of input ports for the device.

required
num_outputs int

Number of output ports for the device.

required
reciprocal bool

If True, the device exhibits reciprocal behavior (S = S^T). This is typical for passive optical devices. Defaults to True.

True
diagonal bool

If True, creates a diagonal coupling matrix (each input couples to only one output). If False, creates full coupling between all input-output pairs. Defaults to False.

False

Returns:

Type Description
SCooModel

A jitted model function that takes wl and returns the SCoo tuple (Si, Sj, Sx, port_map).

Examples:

Create a 4×4 star coupler:

import sax

star_coupler = sax.models.unitary(4, 4, reciprocal=True, diagonal=False)
Si, Sj, Sx, port_map = star_coupler(wl=1.55)
# Full 4×4 unitary matrix with all-to-all coupling

Create a 2×8 splitter array:

splitter_array = sax.models.unitary(2, 8, reciprocal=True)
Si, Sj, Sx, port_map = splitter_array(wl=1.55)
# Each input couples to all 8 outputs

Create diagonal routing device:

router = sax.models.unitary(4, 4, diagonal=True, reciprocal=True)
Si, Sj, Sx, port_map = router(wl=1.55)
# Each input couples to only one output (permutation matrix)

Non-reciprocal device (e.g., isolator array):

isolator_array = sax.models.unitary(3, 3, reciprocal=False)
Si, Sj, Sx, port_map = isolator_array(wl=1.55)
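# Asymmetric transmission characteristics

Note

The algorithm works by:
1. Creating an initial 0/1 coupling matrix based on the diagonal flag
2. Computing the SVD S = U @ diag(s) @ V† and setting every nonzero singular value to 1
3. Taking the elementwise square root to turn power fractions into amplitudes
4. Extracting the specified input/output submatrix

The resulting S-matrix satisfies:
- Power conservation per port: for square, reciprocal devices, the squared magnitudes in each column sum to 1 (in the all-to-all case, every coupling amplitude is 1/sqrt(N))
- Unitarity (S† @ S = I) only for square, reciprocal, diagonal devices; in the all-to-all case, the equal real amplitudes make the columns non-orthogonal
- Proper reciprocity if requested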

This model is ideal for:
- Prototyping arbitrary unitary transformations
- Modeling star couplers and fan-out devices
- Quantum photonic circuit elements
- Network analysis of complex optical systems

For specific device physics, consider using dedicated models like MMI, coupler, or custom implementations with proper material and geometric parameters.
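The four algorithm steps described in this note can be re-created in plain NumPy to inspect the resulting matrix. This sketch mirrors the source listing below but is not SAX itself. `unitary_dense` is a hypothetical name. Zeroing entries below `eps` before the square root is an assumption of the sketch; it stands in for the `S > sax.EPS` filter that SAX applies afterwards.

```python
import numpy as np


def unitary_dense(num_inputs, num_outputs, reciprocal=True, diagonal=False, eps=1e-12):
    """Hypothetical NumPy re-creation of the steps above (dense, real-valued)."""
    N = max(num_inputs, num_outputs)
    S = np.zeros((2 * N, 2 * N))
    r = np.arange(N)
    # 1. initial 0/1 coupling pattern (inputs first, then outputs)
    if diagonal:
        S[r, N + r] = 1
        if reciprocal:
            S[N + r, r] = 1
    else:
        S[:N, N:] = 1
        if reciprocal:
            S[N:, :N] = 1
    # 2. SVD, replacing every nonzero singular value by 1
    U, s, Vh = np.linalg.svd(S)
    S = U @ np.diag(np.where(s > eps, 1.0, 0.0)) @ Vh
    # 3. drop round-off, then elementwise sqrt: power fractions -> amplitudes
    S = np.sqrt(np.where(S > eps, S, 0.0))
    # 4. keep only the requested input and output ports
    keep = np.concatenate([np.arange(num_inputs), N + np.arange(num_outputs)])
    return S[np.ix_(keep, keep)]


star = unitary_dense(4, 4)
print(np.round(star[0], 3))  # input 0: amplitude 0.5 to each of the 4 outputs
print((star**2).sum(axis=0))  # every column carries unit power
print(np.allclose(star.T @ star, np.eye(8)))  # False: columns are not orthogonal
```

The last line illustrates the caveat in the note: equal real amplitudes conserve power for a single excited port, but a strictly unitary all-to-all coupler would also need suitable phases.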

Source code in src/sax/models/factories.py
@validate_call
def unitary(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> sax.SCooModel:
    """Generate a unitary N×M optical device model.

    This function creates a model for a unitary optical device that conserves
    power while providing controllable coupling between input and output ports.
    The model generates a physically realizable S-matrix through singular value
    decomposition and proper normalization.

    Unitary devices are fundamental building blocks in photonics, including
    star couplers, multimode interference devices, and arbitrary unitary
    transformations for quantum photonic circuits.

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        reciprocal: If True, the device exhibits reciprocal behavior (S = S^T).
            This is typical for passive optical devices. Defaults to True.
        diagonal: If True, creates a diagonal coupling matrix (each input
            couples to only one output). If False, creates full coupling
            between all input-output pairs. Defaults to False.

    Returns:
        A jitted model function that takes `wl` and returns the SCoo tuple
        `(Si, Sj, Sx, port_map)`.

    Examples:
        Create a 4×4 star coupler:

        ```python
        import sax

        star_coupler = sax.models.unitary(4, 4, reciprocal=True, diagonal=False)
        Si, Sj, Sx, port_map = star_coupler(wl=1.55)
        # Full 4×4 unitary matrix with all-to-all coupling
        ```

        Create a 2×8 splitter array:

        ```python
        splitter_array = sax.models.unitary(2, 8, reciprocal=True)
        Si, Sj, Sx, port_map = splitter_array(wl=1.55)
        # Each input couples to all 8 outputs
        ```

        Create diagonal routing device:

        ```python
        router = sax.models.unitary(4, 4, diagonal=True, reciprocal=True)
        Si, Sj, Sx, port_map = router(wl=1.55)
        # Each input couples to only one output (permutation matrix)
        ```

        Non-reciprocal device (e.g., isolator array):

        ```python
        isolator_array = sax.models.unitary(3, 3, reciprocal=False)
        Si, Sj, Sx, port_map = isolator_array(wl=1.55)
        # Asymmetric transmission characteristics
        ```

    Note:
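        The algorithm works by:
        1. Creating an initial 0/1 coupling matrix based on the diagonal flag
        2. Computing the SVD S = U @ diag(s) @ V† and setting every nonzero
           singular value to 1
        3. Taking the elementwise square root to turn power fractions into
           amplitudes
        4. Extracting the specified input/output submatrix

        The resulting S-matrix satisfies:
        - Power conservation per port: for square, reciprocal devices, the
          squared magnitudes in each column sum to 1 (in the all-to-all case,
          every coupling amplitude is 1/sqrt(N))
        - Unitarity (S† @ S = I) only for square, reciprocal, diagonal devices;
          in the all-to-all case, the equal real amplitudes make the columns
          non-orthogonal
        - Proper reciprocity if requested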

        This model is ideal for:
        - Prototyping arbitrary unitary transformations
        - Modeling star couplers and fan-out devices
        - Quantum photonic circuit elements
        - Network analysis of complex optical systems

        For specific device physics, consider using dedicated models
        like MMI, coupler, or custom implementations with proper
        material and geometric parameters.
    """
    # let's create the squared S-matrix:
    N = max(num_inputs, num_outputs)
    S = jnp.zeros((2 * N, 2 * N), dtype=float)

    if not diagonal:
        S = S.at[:N, N:].set(1)
    else:
        r = jnp.arange(
            N,
            dtype=int,
        )  # reciprocal only works if num_inputs == num_outputs!
        S = S.at[r, N + r].set(1)

    if reciprocal:
        if not diagonal:
            S = S.at[N:, :N].set(1)
        else:
            r = jnp.arange(
                N,
                dtype=int,
            )  # reciprocal only works if num_inputs == num_outputs!
            S = S.at[N + r, r].set(1)

    # Now we need to normalize the squared S-matrix
    U, s, V = jnp.linalg.svd(S, full_matrices=False)
    S = jnp.sqrt(U @ jnp.diag(jnp.where(s > sax.EPS, 1, 0)) @ V)

    # Now create subset of this matrix we're interested in:
    r = jnp.concatenate(
        [jnp.arange(num_inputs, dtype=int), N + jnp.arange(num_outputs, dtype=int)],
        0,
    )
    S = S[r, :][:, r]

    # let's convert it in SCOO format:
    Si, Sj = jnp.where(S > sax.EPS)
    Sx = S[Si, Sj]

    # the last missing piece is a port map:
    p = sax.PortNamer(num_inputs, num_outputs)
    pm = {
        **{p[i]: i for i in range(num_inputs)},
        **{p[i + num_inputs]: i + num_inputs for i in range(num_outputs)},
    }

    @validate_call
    def func(*, wl: sax.FloatArrayLike = 1.5) -> sax.SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
    func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
    return jax.jit(func)