
models

SAX Models.

Modules:

| Name | Description |
| --- | --- |
| bends | SAX Bend Models. |
| couplers | SAX Coupler Models. |
| crossings | SAX Crossing Models. |
| factories | SAX Default Models. |
| mmis | SAX MMI Models. |
| splitters | SAX Default Models. |
| straight | SAX Default Models. |

Functions:

| Name | Description |
| --- | --- |
| attenuator | Simple optical attenuator model. |
| bend | Simple waveguide bend model. |
| copier | Copy 100% of the power at the input ports to all output ports. |
| coupler | Dispersive directional coupler model. |
| coupler_ideal | Ideal 2x2 directional coupler model. |
| crossing_ideal | Ideal waveguide crossing model. |
| grating_coupler | Grating coupler model for fiber-chip coupling. |
| mmi1x2 | Realistic 1x2 MMI splitter model with dispersion and loss. |
| mmi1x2_ideal | Ideal 1x2 multimode interference (MMI) splitter model. |
| mmi2x2 | Realistic 2x2 MMI coupler model with dispersion and asymmetry. |
| mmi2x2_ideal | Ideal 2x2 multimode interference (MMI) coupler model. |
| model_2port | Generate a general 2-port model with 100% transmission. |
| model_3port | Generate a general 3-port model (1x2 splitter). |
| model_4port | Generate a general 4-port model (2x2 coupler). |
| passthru | Copy 100% of the power at each input port to its corresponding output port. |
| phase_shifter | Simple voltage-controlled phase shifter model. |
| splitter_ideal | Ideal 1x2 power splitter model. |
| unitary | Generate a unitary N×M optical device model. |

attenuator

attenuator(*, wl: FloatArrayLike = WL_C, loss: FloatArrayLike = 0.0) -> SDict

Simple optical attenuator model.

Ports: o1 (in0) → o2 (out0).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |
| loss | FloatArrayLike | Attenuation in decibels (dB). | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-matrix dictionary containing the complex transmission coefficient. |

Examples:

Attenuator:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.attenuator(wl=wl, loss=3.0)
thru = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
```
Source code in src/sax/models/straight.py
@jax.jit
@validate_call
def attenuator(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    loss: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    """Simple optical attenuator model.

    ```{svgbob}
    in0             out0
     o1 =========== o2
    ```

    Args:
        wl: The wavelength in micrometers.
        loss: Attenuation in decibels (dB).

    Returns:
        S-matrix dictionary containing the complex transmission coefficient.

    Examples:
        Attenuator:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.attenuator(wl=wl, loss=3.0)
        thru = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    transmission = jnp.asarray(10 ** (-loss / 20), dtype=complex) * one
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.out0): transmission,
        }
    )
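The attenuator turns `loss` (a power ratio in dB) into a field amplitude with `10 ** (-loss / 20)`, so the plotted power |S|² equals `10 ** (-loss / 10)`. A minimal stdlib sketch of that arithmetic; the helper names `db_to_amplitude` and `amplitude_to_power` are hypothetical and not part of SAX:

```python
import math


def db_to_amplitude(loss_db: float) -> float:
    """Field-amplitude transmission for a power loss given in dB."""
    # dB is defined on power, and power ~ |amplitude|**2, hence the /20.
    return 10 ** (-loss_db / 20)


def amplitude_to_power(amplitude: complex) -> float:
    """Power transmission |t|**2 for a complex amplitude t."""
    return abs(amplitude) ** 2


t = db_to_amplitude(3.0)
power = amplitude_to_power(t)
print(f"amplitude = {t:.4f}, power = {power:.4f}")
# Converting the power back to dB recovers the requested loss.
print(f"loss = {-10 * math.log10(power):.2f} dB")
```

With `loss=3.0`, as in the example above, roughly half the power (about 0.501) is transmitted at every wavelength.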

bend

bend(
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    neff: FloatArrayLike = 2.34,
    ng: FloatArrayLike = 3.4,
    length: FloatArrayLike = 10.0,
    loss_dB_cm: FloatArrayLike = 0.1,
) -> SDict

Simple waveguide bend model.

Ports: o1 (in0) → o2 (out0).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. Defaults to 1.55 μm. | WL_C |
| wl0 | FloatArrayLike | Reference wavelength in micrometers used for dispersion calculation. This is typically the design wavelength where neff is specified. Defaults to 1.55 μm. | WL_C |
| neff | FloatArrayLike | Effective refractive index at the reference wavelength. This value represents the fundamental mode effective index and determines the phase velocity. Defaults to 2.34 (typical for silicon). | 2.34 |
| ng | FloatArrayLike | Group refractive index at the reference wavelength. Used to calculate chromatic dispersion: ng = neff - lambda * d(neff)/d(lambda). Typically ng > neff for normal dispersion. Defaults to 3.4. | 3.4 |
| length | FloatArrayLike | Physical length of the waveguide in micrometers. Determines both the total phase accumulation and loss. Defaults to 10.0 μm. | 10.0 |
| loss_dB_cm | FloatArrayLike | Propagation loss in dB/cm. Includes material absorption, scattering losses, and other loss mechanisms. Set to 0.0 for lossless modeling. Defaults to 0.1 dB/cm. | 0.1 |

Returns:

| Type | Description |
| --- | --- |
| SDict | The bend s-matrix |

Examples:

Basic bend simulation:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.bend(wl=wl, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
thru = np.abs(s[("o1", "o2")]) ** 2

plt.plot(wl, thru)
plt.xlabel("Wavelength (μm)")
plt.ylabel("thru")
```
Note

This model treats the bend as an equivalent straight waveguide and does not account for:

  • Mode coupling between bend eigenmodes
  • Radiation losses due to bending
  • Polarization effects in bends
  • Non-linear dispersion effects

For more accurate bend modeling, consider using dedicated bend models that account for these physical effects.

Source code in src/sax/models/bends.py
@jax.jit
@validate_call
def bend(
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    neff: sax.FloatArrayLike = 2.34,
    ng: sax.FloatArrayLike = 3.4,
    length: sax.FloatArrayLike = 10.0,
    loss_dB_cm: sax.FloatArrayLike = 0.1,
) -> sax.SDict:
    """Simple waveguide bend model.

    ```{svgbob}
              out0
              o2

             /
         __.'
    o1
    in0
    ```

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Reference wavelength in micrometers used for dispersion calculation.
            This is typically the design wavelength where neff is specified.
            Defaults to 1.55 μm.
        neff: Effective refractive index at the reference wavelength. This value
            represents the fundamental mode effective index and determines the
            phase velocity. Defaults to 2.34 (typical for silicon).
        ng: Group refractive index at the reference wavelength. Used to calculate
            chromatic dispersion: ng = neff - lambda * d(neff)/d(lambda).
            Typically ng > neff for normal dispersion. Defaults to 3.4.
        length: Physical length of the waveguide in micrometers. Determines both
            the total phase accumulation and loss. Defaults to 10.0 μm.
        loss_dB_cm: Propagation loss in dB/cm. Includes material absorption,
            scattering losses, and other loss mechanisms. Set to 0.0 for
            lossless modeling. Defaults to 0.1 dB/cm.

    Returns:
        The bend s-matrix

    Examples:
        Basic bend simulation:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax
        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.bend(wl=wl, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
        thru = np.abs(s[("o1", "o2")]) ** 2

        plt.plot(wl, thru)
        plt.xlabel("Wavelength (μm)")
        plt.ylabel("thru")
        ```

    Note:
        This model treats the bend as an equivalent straight waveguide and does not
        account for:

        - Mode coupling between bend eigenmodes
        - Radiation losses due to bending
        - Polarization effects in bends
        - Non-linear dispersion effects

        For more accurate bend modeling, consider using dedicated bend models that
        account for these physical effects.
    """
    return straight(
        wl=wl,
        wl0=wl0,
        neff=neff,
        ng=ng,
        length=length,
        loss_dB_cm=loss_dB_cm,
    )
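`bend` forwards its arguments to `straight`, whose source is not reproduced on this page. Reading the docstring relation ng = neff - λ·d(neff)/dλ at the reference wavelength gives d(neff)/dλ = (neff - ng)/λ₀. The sketch below uses that first-order expansion, converts `length` from μm to cm for the loss (1 μm = 1e-4 cm), and assumes an exp(+jφ) phase convention. It only illustrates how the parameters interact and is not SAX's implementation; `straight_sketch` is a hypothetical name.

```python
import cmath
import math


def straight_sketch(wl, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss_dB_cm=0.1):
    """Hypothetical first-order model of a dispersive straight waveguide."""
    # From ng = neff - wl * dneff/dwl evaluated at wl0.
    dneff_dwl = (neff - ng) / wl0
    neff_wl = neff + (wl - wl0) * dneff_dwl
    # Accumulated phase over the waveguide (length and wl both in um).
    phase = 2 * math.pi * neff_wl * length / wl
    # Total loss in dB: dB/cm times the length converted from um to cm.
    loss_dB = loss_dB_cm * length * 1e-4
    amplitude = 10 ** (-loss_dB / 20)
    # The exp(+1j * phase) sign convention is an assumption.
    return amplitude * cmath.exp(1j * phase), neff_wl


t, n_eff = straight_sketch(1.55, length=50.0, neff=2.35, ng=3.5, loss_dB_cm=0.1)
print(abs(t) ** 2, n_eff)
```

Differentiating `neff_wl` numerically and re-applying ng = neff - λ·d(neff)/dλ recovers the `ng` that was passed in, which is a quick consistency check on the expansion.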

copier

copier(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> SCooModel

Copy 100% of the power at the input ports to all output ports.

Ports: inputs o1 (in0), o2 (in1), … on the left; outputs up to o_N-1 on the right.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_inputs | int | Number of input ports for the device. | required |
| num_outputs | int | Number of output ports for the device. | required |
| reciprocal | bool | If True, the device exhibits reciprocal behavior where forward and reverse transmissions are equal. Defaults to True. | True |
| diagonal | bool | If True, creates diagonal coupling (each input couples to only one output). If False, creates full coupling between all input-output pairs. Defaults to False. | False |

Returns:

| Type | Description |
| --- | --- |
| SCooModel | A copier model |

Examples:

A 1×4 optical amplifier/splitter:

```python
import sax

amp_splitter = sax.models.copier(1, 4, reciprocal=False)
Si, Sj, Sx, port_map = amp_splitter(wl=1.55)
# Single input amplified and split to 4 outputs
```
Source code in src/sax/models/factories.py
@validate_call
def copier(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> sax.SCooModel:
    """Copy 100% of the power at the input ports to all output ports.

    ```{svgbob}
    o_ni -1 *---+-----+--* o_ni
                |     |
                  .
                  .
                  .
                |     |
         o2 *---+-----+--* o_N -2
        in1     |     |
                |     |
         o1 *---+-----+--* o_N -1
        in0
    ```

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        reciprocal: If True, the device exhibits reciprocal behavior where
            forward and reverse transmissions are equal. Defaults to True.
        diagonal: If True, creates diagonal coupling (each input couples to
            only one output). If False, creates full coupling between all
            input-output pairs. Defaults to False.

    Returns:
        A copier model

    Examples:
        A 1×4 optical amplifier/splitter:

        ```python
        import sax

        amp_splitter = sax.models.copier(1, 4, reciprocal=False)
        Si, Sj, Sx, port_map = amp_splitter(wl=1.55)
        # Single input amplified and split to 4 outputs
        ```

    """
    # let's create the squared S-matrix:
    S = jnp.zeros((num_inputs + num_outputs, num_inputs + num_outputs), dtype=float)

    if not diagonal:
        S = S.at[:num_inputs, num_inputs:].set(1)
    else:
        r = jnp.arange(
            num_inputs,
            dtype=int,
        )  # == range(num_outputs) # reciprocal only works if num_inputs == num_outputs!
        S = S.at[r, num_inputs + r].set(1)

    if reciprocal:
        if not diagonal:
            S = S.at[num_inputs:, :num_inputs].set(1)
        else:
            # reciprocal only works if num_inputs == num_outputs!
            r = jnp.arange(num_inputs, dtype=int)  # == range(num_outputs)
            S = S.at[num_inputs + r, r].set(1)

    # let's convert it in SCOO format:
    Si, Sj = jnp.where(jnp.sqrt(sax.EPS) < S)
    Sx = S[Si, Sj]

    # the last missing piece is a port map:
    p = sax.PortNamer(num_inputs, num_outputs)
    pm = {
        **{p[i]: i for i in range(num_inputs)},
        **{p[i + num_inputs]: i + num_inputs for i in range(num_outputs)},
    }

    @validate_call
    def func(wl: sax.FloatArrayLike = 1.5) -> sax.SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
    func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
    return jax.jit(func)
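The copier builds a dense 0/1 matrix and then keeps only its non-zero entries as (row, column, value) triplets, which is the `Si, Sj, Sx` part of the SCOO tuple. Since every input is connected to every output with amplitude 1, the device is not power-conserving when there are several outputs, consistent with the "amplifier/splitter" wording in the example. A stdlib sketch of the same steps, assuming the row/column layout used in the source above; `copier_scoo` is a hypothetical name, and it tests for exact non-zeros instead of the `sqrt(sax.EPS)` threshold:

```python
def copier_scoo(num_inputs, num_outputs, *, reciprocal=True, diagonal=False):
    """Hypothetical stdlib version of the copier's dense-to-COO conversion."""
    n = num_inputs + num_outputs
    dense = [[0.0] * n for _ in range(n)]
    for i in range(num_inputs):
        # Diagonal: input i reaches only output i; otherwise it reaches all outputs.
        targets = [num_inputs + i] if diagonal else range(num_inputs, n)
        for j in targets:
            dense[i][j] = 1.0
            if reciprocal:
                dense[j][i] = 1.0  # mirrored entry for the reverse direction
    # COO format: keep only the non-zero entries, scanning row by row.
    Si, Sj, Sx = [], [], []
    for i in range(n):
        for j in range(n):
            if dense[i][j] != 0.0:
                Si.append(i)
                Sj.append(j)
                Sx.append(dense[i][j])
    return Si, Sj, Sx


# The 1x4 non-reciprocal copier from the example above.
Si, Sj, Sx = copier_scoo(1, 4, reciprocal=False)
print(Si, Sj, Sx)
```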

coupler

coupler(
    *,
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    length: FloatArrayLike = 0.0,
    coupling0: FloatArrayLike = 0.2,
    dk1: FloatArrayLike = 1.2435,
    dk2: FloatArrayLike = 5.3022,
    dn: FloatArrayLike = 0.02,
    dn1: FloatArrayLike = 0.1169,
    dn2: FloatArrayLike = 0.4821,
) -> SDict

Dispersive directional coupler model.

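Ports: o1 (in0) and o2 (in1) on the input side, o4 (out0) and o3 (out1) on the output side; the coupling section of the given length sits between two coupling0/2 end regions.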

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. Defaults to 1.55 μm. | WL_C |
| wl0 | FloatArrayLike | Reference wavelength in micrometers for dispersion expansion. Typically the center design wavelength. Defaults to 1.55 μm. | WL_C |
| length | FloatArrayLike | Coupling length in micrometers. This is the length over which the two waveguides are in close proximity. Defaults to 0.0 μm. | 0.0 |
| coupling0 | FloatArrayLike | Base coupling coefficient at the reference wavelength from bend regions or other coupling mechanisms, typically obtained from FDTD simulations or measurements. Defaults to 0.2. | 0.2 |
| dk1 | FloatArrayLike | First-order derivative of coupling coefficient with respect to wavelength (∂κ/∂λ). Units: μm⁻¹. Defaults to 1.2435. | 1.2435 |
| dk2 | FloatArrayLike | Second-order derivative of coupling coefficient with respect to wavelength (∂²κ/∂λ²). Units: μm⁻². Defaults to 5.3022. | 5.3022 |
| dn | FloatArrayLike | Effective index difference between even and odd supermodes at the reference wavelength. Determines the beat length. Defaults to 0.02. | 0.02 |
| dn1 | FloatArrayLike | First-order derivative of effective index difference with respect to wavelength (∂Δn/∂λ). Units: μm⁻¹. Defaults to 0.1169. | 0.1169 |
| dn2 | FloatArrayLike | Second-order derivative of effective index difference with respect to wavelength (∂²Δn/∂λ²). Units: μm⁻². Defaults to 0.4821. | 0.4821 |

Returns:

| Type | Description |
| --- | --- |
| SDict | The coupler s-matrix |

Examples:

Basic dispersive coupler:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.coupler(
    wl=wl,
    length=15.7,
    coupling0=0.0,
    dn=0.02,
)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
```
Note

The coupling strength follows the formula:

κ_total = κ₀ + κ₁·L

Where:

  • κ₀ = coupling0 + dk1·(λ-λ₀) + 0.5·dk2·(λ-λ₀)²
  • κ₁ = π·Δn(λ)/λ, with Δn(λ) = dn + dn1·(λ-λ₀) + 0.5·dn2·(λ-λ₀)²
  • L is the coupling length (length)

The through amplitude is cos(κ_total) and the cross amplitude is -j·sin(κ_total).
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def coupler(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    length: sax.FloatArrayLike = 0.0,
    coupling0: sax.FloatArrayLike = 0.2,
    dk1: sax.FloatArrayLike = 1.2435,
    dk2: sax.FloatArrayLike = 5.3022,
    dn: sax.FloatArrayLike = 0.02,
    dn1: sax.FloatArrayLike = 0.1169,
    dn2: sax.FloatArrayLike = 0.4821,
) -> sax.SDict:
    r"""Dispersive directional coupler model.

    ```{svgbob}
            in1                out1
             o2                o3
              *                *
               \              /
                '------------'
    coupling0 /2   coupling   coupling0 /2
                .------------.
               /  <-length -> \
              *                *
             o1                o4
            in0                out0
    ```

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Reference wavelength in micrometers for dispersion expansion.
            Typically the center design wavelength. Defaults to 1.55 μm.
        length: Coupling length in micrometers. This is the length over which
            the two waveguides are in close proximity. Defaults to 0.0 μm.
        coupling0: Base coupling coefficient at the reference wavelength from
            bend regions or other coupling mechanisms. Obtained from FDTD
            simulations or measurements. Defaults to 0.2.
        dk1: First-order derivative of coupling coefficient with respect to
            wavelength (∂κ/∂λ). Units: μm⁻¹. Defaults to 1.2435.
        dk2: Second-order derivative of coupling coefficient with respect to
            wavelength (∂²κ/∂λ²). Units: μm⁻². Defaults to 5.3022.
        dn: Effective index difference between even and odd supermodes at
            the reference wavelength. Determines beating length. Defaults to 0.02.
        dn1: First-order derivative of effective index difference with respect
            to wavelength (∂Δn/∂λ). Units: μm⁻¹. Defaults to 0.1169.
        dn2: Second-order derivative of effective index difference with respect
            to wavelength (∂²Δn/∂λ²). Units: μm⁻². Defaults to 0.4821.

    Returns:
        The coupler s-matrix

    Examples:
        Basic dispersive coupler:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.coupler(
            wl=wl,
            length=15.7,
            coupling0=0.0,
            dn=0.02,
        )
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```

    Note:
        The coupling strength follows the formula:

        κ_total = κ₀ + κ₁ * length

        Where:

        - κ₀ = coupling0 + dk1*(λ-λ₀) + 0.5*dk2*(λ-λ₀)²
        - κ₁ = π*Δn(λ)/λ, with Δn(λ) = dn + dn1*(λ-λ₀) + 0.5*dn2*(λ-λ₀)²
        - The through amplitude is cos(κ_total) and the cross amplitude is
          -1j*sin(κ_total)

    """
    dwl = wl - wl0
    dn = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    kappa0 = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2
    kappa1 = jnp.pi * dn / wl

    tau = jnp.cos(kappa0 + kappa1 * length)
    kappa = -jnp.sin(kappa0 + kappa1 * length)
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): 1j * kappa,
            (p.in1, p.out0): 1j * kappa,
            (p.in1, p.out1): tau,
        }
    )
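The source above evaluates a single accumulated coupling phase, κ₀ + κ₁·length, and returns cos(...) as the through amplitude and -1j·sin(...) as the cross amplitude, so |through|² + |cross|² = 1 at every wavelength. A stdlib sketch of the same expressions (`coupler_sketch` is a hypothetical name) also shows the length for a 50/50 split at λ₀ when coupling0 = 0: κ₁·L = π/4, i.e. L = λ₀/(4·dn).

```python
import math


def coupler_sketch(wl, wl0=1.55, length=0.0, coupling0=0.2, dk1=1.2435,
                   dk2=5.3022, dn=0.02, dn1=0.1169, dn2=0.4821):
    """Hypothetical re-statement of the coupler's through/cross amplitudes."""
    dwl = wl - wl0
    # Second-order expansions around the reference wavelength.
    dn_wl = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    kappa0 = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2
    kappa1 = math.pi * dn_wl / wl
    total = kappa0 + kappa1 * length
    through = math.cos(total)
    cross = 1j * -math.sin(total)
    return through, cross


# Length giving kappa1 * L = pi / 4 at wl0 with coupling0 = 0.
L_3dB = 1.55 / (4 * 0.02)
through, cross = coupler_sketch(1.55, length=L_3dB, coupling0=0.0)
print(L_3dB, abs(through) ** 2, abs(cross) ** 2)
```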

coupler_ideal

coupler_ideal(*, wl: FloatArrayLike = WL_C, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 2x2 directional coupler model.

Ports: o1 (in0) and o2 (in1) on the input side, o4 (out0) and o3 (out1) on the output side.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | The wavelength of the simulation in micrometers. | WL_C |
| coupling | FloatArrayLike | Power coupling coefficient between 0 and 1. Defaults to 0.5. | 0.5 |

Returns:

| Type | Description |
| --- | --- |
| SDict | The coupler s-matrix |

Examples:

Ideal coupler:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.coupler_ideal(
    wl=wl,
    coupling=0.3,
)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
```
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def coupler_ideal(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    coupling: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Ideal 2x2 directional coupler model.

    ```{svgbob}
     in1          out1
      o2          o3
       *          *
        \        /
         '------'
         coupling
         .------.
        /        \
       *          *
      o1          o4
     in0          out0
    ```

    Args:
        wl: The wavelength of the simulation in micrometers.
        coupling: Power coupling coefficient between 0 and 1. Defaults to 0.5.

    Returns:
        The coupler s-matrix

    Examples:
        Ideal coupler:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.coupler_ideal(
            wl=wl,
            coupling=0.3,
        )
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(wl)
    kappa = jnp.asarray(coupling**0.5) * one
    tau = jnp.asarray((1 - coupling) ** 0.5) * one
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): 1j * kappa,
            (p.in1, p.out0): 1j * kappa,
            (p.in1, p.out1): tau,
        },
    )
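`coupler_ideal` takes the power coupling c and uses amplitudes √c on the cross path (multiplied by 1j) and √(1 - c) on the through path. The 90° phase on the cross path is what makes the 2×2 transfer matrix from (in0, in1) to (out0, out1) unitary, i.e. lossless. A stdlib sketch that builds that matrix and checks MᴴM = I; the helper names are hypothetical:

```python
import math


def coupler_ideal_matrix(coupling):
    """Hypothetical 2x2 transfer matrix, rows/cols ordered (port 0, port 1)."""
    kappa = math.sqrt(coupling)
    tau = math.sqrt(1 - coupling)
    return [[tau, 1j * kappa], [1j * kappa, tau]]


def is_unitary(m, tol=1e-12):
    """Check that conj(M)^T @ M equals the identity within tol."""
    n = len(m)
    for i in range(n):
        for j in range(n):
            val = sum(complex(m[k][i]).conjugate() * m[k][j] for k in range(n))
            if abs(val - (1.0 if i == j else 0.0)) > tol:
                return False
    return True


m = coupler_ideal_matrix(0.3)
print(abs(m[0][1]) ** 2, is_unitary(m))
```

Dropping the 1j (making both amplitudes real and positive) breaks the check, because the off-diagonal terms of MᴴM no longer cancel.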

crossing_ideal

crossing_ideal(wl: FloatArrayLike = WL_C) -> SDict

Ideal waveguide crossing model.

Ports: o1 (in0) → o3 (out0) and o2 (in1) → o4 (out1), crossing at the center.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Wavelength in micrometers. | WL_C |

Returns:

| Type | Description |
| --- | --- |
| SDict | The crossing s-matrix |

Examples:

Ideal crossing:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.crossing_ideal(wl=wl)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s.get(("o1", "o2"), jnp.zeros_like(wl))) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
```
Source code in src/sax/models/crossings.py
@jax.jit
@validate_call
def crossing_ideal(wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal waveguide crossing model.

    ```{svgbob}
            in1
             o2
              *
              |
              |
     o1 *-----+-----* o3
    in0       |     out0
              |
              *
             o4
           out1
    ```

    Args:
        wl: Wavelength in micrometers.

    Returns:
        The crossing s-matrix

    Examples:
        Ideal crossing:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import jax.numpy as jnp
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.crossing_ideal(wl=wl)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s.get(("o1", "o2"), jnp.zeros_like(wl))) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o3): one,
            (p.o2, p.o4): one,
        }
    )
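The crossing only defines the two straight-through paths, o1→o3 and o2→o4, and relies on `sax.reciprocal` to add the reverse entries; every other port pair is simply absent, which is why the example reads ("o1", "o2") through `s.get(..., zeros)`. Expanded to a dense matrix, the result is a permutation matrix. A stdlib sketch, assuming `sax.reciprocal` mirrors each (a, b) entry to (b, a); `mirror` and `to_dense` are hypothetical helpers:

```python
PORTS = ["o1", "o2", "o3", "o4"]


def mirror(sdict):
    """Hypothetical stand-in for sax.reciprocal: add (b, a) for each (a, b)."""
    out = dict(sdict)
    for (a, b), value in sdict.items():
        out.setdefault((b, a), value)
    return out


def to_dense(sdict, ports):
    """Place S-dict entries into a dense matrix; missing pairs stay 0."""
    index = {p: i for i, p in enumerate(ports)}
    dense = [[0.0] * len(ports) for _ in ports]
    for (a, b), value in sdict.items():
        dense[index[a]][index[b]] = value
    return dense


s = mirror({("o1", "o3"): 1.0, ("o2", "o4"): 1.0})
dense = to_dense(s, PORTS)
for port, row in zip(PORTS, dense):
    print(port, row)
```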

grating_coupler

grating_coupler(
    *,
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    loss: FloatArrayLike = 0.0,
    reflection: FloatArrayLike = 0.0,
    reflection_fiber: FloatArrayLike = 0.0,
    bandwidth: FloatArrayLike = 0.04,
) -> SDict

Grating coupler model for fiber-chip coupling.

Ports: o1 (in0) on the waveguide side, o2 (out0) on the fiber side.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for spectral analysis. Defaults to 1.55 μm. | WL_C |
| wl0 | FloatArrayLike | Center wavelength in micrometers where peak transmission occurs. This is the design wavelength of the grating. Defaults to 1.55 μm. | WL_C |
| loss | FloatArrayLike | Insertion loss in dB at the center wavelength. Includes coupling efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB. | 0.0 |
| reflection | FloatArrayLike | Reflection coefficient from the waveguide side (chip interface). Represents reflections back into the waveguide from grating discontinuities. Range: 0 to 1. Defaults to 0.0. | 0.0 |
| reflection_fiber | FloatArrayLike | Reflection coefficient from the fiber side (top interface). Represents reflections back toward the fiber from the grating surface. Range: 0 to 1. Defaults to 0.0. | 0.0 |
| bandwidth | FloatArrayLike | 3 dB bandwidth in micrometers. Determines the spectral width of the Gaussian transmission profile. Typical values: 20-50 nm. Defaults to 40e-3 μm (40 nm). | 0.04 |

Returns:

| Type | Description |
| --- | --- |
| SDict | The grating coupler s-matrix |

Examples:

Basic grating coupler:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = np.linspace(1.5, 1.6, 101)
s = sax.models.grating_coupler(
    wl=wl,
    loss=3.0,
    bandwidth=0.035,
)
plt.figure()
plt.plot(wl, np.abs(s[("o1", "o2")]) ** 2, label="Transmission")
plt.xlabel("Wavelength (μm)")
plt.ylabel("Power")
plt.legend()
```
Note

The transmission profile follows a Gaussian shape: T(λ) = T₀ · exp(-(λ-λ₀)²/(2σ²)), with peak amplitude T₀ = 10^(-loss/20).

Where σ = bandwidth / (2√(2ln(2))) converts FWHM to Gaussian width.

This model assumes:

  • Gaussian spectral response (typical for uniform gratings)
  • Wavelength-independent loss coefficient
  • Linear polarization (TE or TM)
  • Single-mode fiber coupling
  • No higher-order diffraction effects
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def grating_coupler(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    loss: sax.FloatArrayLike = 0.0,
    reflection: sax.FloatArrayLike = 0.0,
    reflection_fiber: sax.FloatArrayLike = 0.0,
    bandwidth: sax.FloatArrayLike = 40e-3,
) -> sax.SDict:
    """Grating coupler model for fiber-chip coupling.

    ```{svgbob}
                  out0
                  o2
             /  /  /  /
            /  /  /  /  fiber
           /  /  /  /
           _   _   _
     o1  _| |_| |_| |__
    in0
    ```

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            spectral analysis. Defaults to 1.55 μm.
        wl0: Center wavelength in micrometers where peak transmission occurs.
            This is the design wavelength of the grating. Defaults to 1.55 μm.
        loss: Insertion loss in dB at the center wavelength. Includes coupling
            efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB.
        reflection: Reflection coefficient from the waveguide side (chip interface).
            Represents reflections back into the waveguide from grating discontinuities.
            Range: 0 to 1. Defaults to 0.0.
        reflection_fiber: Reflection coefficient from the fiber side (top interface).
            Represents reflections back toward the fiber from the grating surface.
            Range: 0 to 1. Defaults to 0.0.
        bandwidth: 3dB bandwidth in micrometers. Determines the spectral width
            of the Gaussian transmission profile. Typical values: 20-50 nm.
            Defaults to 40e-3 μm (40 nm).

    Returns:
        The grating coupler s-matrix

    Examples:
        Basic grating coupler:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = np.linspace(1.5, 1.6, 101)
        s = sax.models.grating_coupler(
            wl=wl,
            loss=3.0,
            bandwidth=0.035,
        )
        plt.figure()
        plt.plot(wl, np.abs(s[("o1", "o2")]) ** 2, label="Transmission")
        plt.xlabel("Wavelength (μm)")
        plt.ylabel("Power")
        plt.legend()
        ```

    Note:
        The transmission profile follows a Gaussian shape:
        T(λ) = T₀ * exp(-(λ-λ₀)²/(2*σ²)), with peak amplitude T₀ = 10^(-loss/20).

        Where σ = bandwidth / (2*√(2*ln(2))) converts FWHM to Gaussian width.

        This model assumes:

        - Gaussian spectral response (typical for uniform gratings)
        - Wavelength-independent loss coefficient
        - Linear polarization (TE or TM)
        - Single-mode fiber coupling
        - No higher-order diffraction effects

    """
    one = jnp.ones_like(wl)
    reflection = jnp.asarray(reflection) * one
    reflection_fiber = jnp.asarray(reflection_fiber) * one
    amplitude = jnp.asarray(10 ** (-loss / 20))
    sigma = jnp.asarray(bandwidth / (2 * jnp.sqrt(2 * jnp.log(2))))
    transmission = jnp.asarray(amplitude * jnp.exp(-((wl - wl0) ** 2) / (2 * sigma**2)))
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.in0): reflection,
            (p.in0, p.out0): transmission,
            (p.out0, p.in0): transmission,
            (p.out0, p.out0): reflection_fiber,
        }
    )
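Taken literally, the source above applies the FWHM-to-σ conversion to the field amplitude, exp(-(λ-λ₀)²/(2σ²)). At λ₀ ± bandwidth/2 the amplitude is therefore half its peak, so the power |S|² is a quarter of its peak there, and the -3 dB points in power sit at λ₀ ± bandwidth/(2√2). A stdlib sketch that checks both points; `grating_amplitude` is a hypothetical helper that repeats only the transmission term (reflections omitted):

```python
import math


def grating_amplitude(wl, wl0=1.55, loss=0.0, bandwidth=0.04):
    """Hypothetical repeat of the grating coupler's transmission amplitude."""
    peak = 10 ** (-loss / 20)
    # Convert the FWHM-style bandwidth to a Gaussian standard deviation.
    sigma = bandwidth / (2 * math.sqrt(2 * math.log(2)))
    return peak * math.exp(-((wl - wl0) ** 2) / (2 * sigma**2))


bw = 0.04
peak = grating_amplitude(1.55, bandwidth=bw)
# Amplitude ratio at half the bandwidth away from the center.
amp_ratio = grating_amplitude(1.55 + bw / 2, bandwidth=bw) / peak
# Power ratio at bandwidth / (2 * sqrt(2)) away from the center.
offset = bw / (2 * math.sqrt(2))
power_ratio = (grating_amplitude(1.55 + offset, bandwidth=bw) / peak) ** 2
print(amp_ratio, power_ratio)
```

For the default bandwidth of 0.04 μm, this places the -3 dB power width near 28 nm rather than 40 nm.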

mmi1x2

mmi1x2(
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    fwhm: FloatArrayLike = 0.2,
    loss_dB: FloatArrayLike = 0.3,
) -> SDict

Realistic 1x2 MMI splitter model with dispersion and loss.

Ports: o1 (in0) input; o2 (out1) and o3 (out0) outputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |
| wl0 | FloatArrayLike | Center wavelength of the MMI in micrometers. | WL_C |
| fwhm | FloatArrayLike | Full width at half maximum bandwidth of the MMI in micrometers. | 0.2 |
| loss_dB | FloatArrayLike | Insertion loss in dB at the center wavelength. | 0.3 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-matrix dictionary representing the dispersive MMI splitter behavior. |

Examples:

Basic 1x2 MMI:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi1x2(wl=wl)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
```
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi1x2(
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    fwhm: sax.FloatArrayLike = 0.2,
    loss_dB: sax.FloatArrayLike = 0.3,
) -> sax.SDict:
    r"""Realistic 1x2 MMI splitter model with dispersion and loss.

    ```{svgbob}
            +-------------+     out1
            |             |---* o2
     o1 *---|             |
    in0     |             |---* o3
            +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.
        wl0: Center wavelength of the MMI in micrometers.
        fwhm: Full width at half maximum bandwidth of the MMI in micrometers.
        loss_dB: Insertion loss in dB at the center wavelength.

    Returns:
        S-matrix dictionary representing the dispersive MMI splitter behavior.

    Examples:
        Basic 1x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi1x2(wl=wl)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB) / 2**0.5

    p = sax.PortNamer(1, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o2): thru,
            (p.o1, p.o3): thru,
        }
    )
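`mmi1x2` divides the amplitude returned by the private helper `_mmi_amp` by √2, so each output carries half of whatever power that helper transmits. `_mmi_amp` is not shown on this page; assuming it peaks at 10^(-loss_dB/20) at wl0 (consistent with loss_dB being the insertion loss at the center wavelength), the per-port power at wl0 is 10^(-loss_dB/10)/2, i.e. about 3.01 dB plus loss_dB below the input. A stdlib sketch of that arithmetic; `per_port_power_db` is a hypothetical name:

```python
import math


def per_port_power_db(loss_dB):
    """Per-output power of a 1x2 split at wl0, in dB relative to the input."""
    # Assumed peak amplitude of the hidden _mmi_amp helper.
    peak_amplitude = 10 ** (-loss_dB / 20)
    # Equal split between o2 and o3: divide the amplitude by sqrt(2).
    port_amplitude = peak_amplitude / 2**0.5
    return 10 * math.log10(port_amplitude**2)


print(f"{per_port_power_db(0.3):.2f} dB")
```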

mmi1x2_ideal

mmi1x2_ideal(wl: FloatArrayLike = WL_C) -> SDict

Ideal 1x2 multimode interference (MMI) splitter model.

Ports: o1 (in0) input; o2 (out1) and o3 (out0) outputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-matrix dictionary representing the ideal MMI splitter behavior. |

Examples:

Ideal 1x2 MMI:

```python
# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi1x2_ideal(wl=wl)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
```
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi1x2_ideal(wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal 1x2 multimode interference (MMI) splitter model.

    ```{svgbob}
            +-------------+     out1
            |             |---* o2
     o1 *---|             |
    in0     |             |---* o3
            +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.

    Returns:
        S-matrix dictionary representing the ideal MMI splitter behavior.

    Examples:
        Ideal 1x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi1x2_ideal(wl=wl)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    return splitter_ideal(wl=wl, coupling=0.5)
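
Because `mmi1x2_ideal` delegates to `splitter_ideal(coupling=0.5)`, its full 3×3 S-matrix can be written out by hand. The NumPy sketch below does this, assuming the port order o1 (in0), o2 (out1), o3 (out0) from the diagram above. Light entering o1 is fully transmitted. Light entering a single output port, however, only returns half of its power to o1. A reciprocal, matched, lossless 3-port cannot exist, so "ideal" here refers to the forward splitting only.

```python
import numpy as np

amp = np.sqrt(0.5)  # splitter_ideal amplitudes for coupling = 0.5 (tau == kappa)

# Port order: o1 (in0), o2 (out1), o3 (out0); reciprocal, so S is symmetric
S = np.array(
    [
        [0.0, amp, amp],
        [amp, 0.0, 0.0],
        [amp, 0.0, 0.0],
    ]
)

power_out = np.sum(np.abs(S) ** 2, axis=0)  # total output power per excited port
is_unitary = np.allclose(S.conj().T @ S, np.eye(3))
```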

mmi2x2

mmi2x2(
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    fwhm: FloatArrayLike = 0.2,
    loss_dB: FloatArrayLike = 0.3,
    shift: FloatArrayLike = 0.0,
    loss_dB_cross: FloatArrayLike | None = None,
    loss_dB_thru: FloatArrayLike | None = None,
    splitting_ratio_cross: FloatArrayLike = 0.5,
    splitting_ratio_thru: FloatArrayLike = 0.5,
) -> SDict

Realistic 2x2 MMI coupler model with dispersion and asymmetry.

Ports: inputs o1 (in0) and o2 (in1); outputs o3 (out1) and o4 (out0).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |
| wl0 | FloatArrayLike | Center wavelength of the MMI in micrometers. | WL_C |
| fwhm | FloatArrayLike | Full width at half maximum bandwidth in micrometers. | 0.2 |
| loss_dB | FloatArrayLike | Insertion loss in dB at the center wavelength. | 0.3 |
| shift | FloatArrayLike | Peak shift of the cross transmission in micrometers. | 0.0 |
| loss_dB_cross | FloatArrayLike \| None | Optional separate insertion loss in dB for the cross ports. If None, loss_dB is used. Allows modeling asymmetric loss. | None |
| loss_dB_thru | FloatArrayLike \| None | Optional separate insertion loss in dB for the bar (through) ports. If None, loss_dB is used. Allows modeling asymmetric loss. | None |
| splitting_ratio_cross | FloatArrayLike | Power splitting ratio for the cross ports (0 to 1). Allows modeling imbalanced coupling. | 0.5 |
| splitting_ratio_thru | FloatArrayLike | Power splitting ratio for the bar ports (0 to 1). Allows modeling imbalanced transmission. | 0.5 |

Returns:

| Type | Description |
|------|-------------|
| SDict | S-matrix dictionary representing the realistic MMI coupler behavior. |

Examples:

Basic 2x2 MMI:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi2x2(wl=wl, shift=0.001)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()

Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi2x2(
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    fwhm: sax.FloatArrayLike = 0.2,
    loss_dB: sax.FloatArrayLike = 0.3,
    shift: sax.FloatArrayLike = 0.0,
    loss_dB_cross: sax.FloatArrayLike | None = None,
    loss_dB_thru: sax.FloatArrayLike | None = None,
    splitting_ratio_cross: sax.FloatArrayLike = 0.5,
    splitting_ratio_thru: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Realistic 2x2 MMI coupler model with dispersion and asymmetry.

    ```{svgbob}
    in1     +-------------+     out1
     o2 *---|             |---* o3
            |             |
     o1 *---|             |---* o4
    in0     +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.
        wl0: Center wavelength of the MMI in micrometers.
        fwhm: Full width at half maximum bandwidth in micrometers.
        loss_dB: Insertion loss in dB at the center wavelength.
        shift: The peak shift of the cross-transmission in micrometers.
        loss_dB_cross: Optional separate insertion loss in dB for cross ports.
            If None, uses loss_dB. Allows modeling of asymmetric loss.
        loss_dB_thru: Optional separate insertion loss in dB for bar (through)
            ports. If None, uses loss_dB. Allows modeling of asymmetric loss.
        splitting_ratio_cross: Power splitting ratio for cross ports (0 to 1).
            Allows modeling of imbalanced coupling. Defaults to 0.5.
        splitting_ratio_thru: Power splitting ratio for bar ports (0 to 1).
            Allows modeling of imbalanced transmission. Defaults to 0.5.

    Returns:
        S-matrix dictionary representing the realistic MMI coupler behavior.

    Examples:
        Basic 2x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi2x2(wl=wl, shift=0.001)
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
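    # Fall back to loss_dB only when no separate loss is given: an explicit 0.0
    # must not be replaced, and `or` on a traced array fails under jax.jit
    loss_dB_cross = loss_dB if loss_dB_cross is None else loss_dB_cross
    loss_dB_thru = loss_dB if loss_dB_thru is None else loss_dB_thru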

    # Convert splitting ratios from power to amplitude by taking the square root
    amplitude_ratio_thru = splitting_ratio_thru**0.5
    amplitude_ratio_cross = splitting_ratio_cross**0.5

    # _mmi_amp already includes the loss, so we don't need to apply it again
    thru = (
        _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB_thru) * amplitude_ratio_thru
    )
    cross = (
        1j
        * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB_cross)
        * amplitude_ratio_cross
    )

    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o3): thru,
            (p.o1, p.o4): cross,
            (p.o2, p.o3): cross,
            (p.o2, p.o4): thru,
        }
    )
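
At the center wavelength, the thru and cross amplitudes of `mmi2x2` reduce to simple expressions. The sketch below re-derives them in NumPy with `shift=0`. It assumes that `_mmi_amp` returns `10 ** (-loss_dB / 20)` at its center wavelength, which is not shown in this excerpt. The result illustrates how separate losses and splitting ratios set the power on each path, and that the cross term leads the thru term by 90°.

```python
import numpy as np


def mmi2x2_center(loss_dB_thru, loss_dB_cross, ratio_thru=0.5, ratio_cross=0.5):
    # Assumption: _mmi_amp(wl0) == 10 ** (-loss_dB / 20) and shift == 0
    thru = 10 ** (-loss_dB_thru / 20) * np.sqrt(ratio_thru)
    cross = 1j * 10 ** (-loss_dB_cross / 20) * np.sqrt(ratio_cross)
    return thru, cross


thru, cross = mmi2x2_center(0.2, 0.4, ratio_thru=0.55, ratio_cross=0.45)
p_thru = abs(thru) ** 2  # (o1, o3) path in the source's port mapping
p_cross = abs(cross) ** 2  # (o1, o4) path
excess_loss_dB = -10 * np.log10(p_thru + p_cross)
phase_deg = np.degrees(np.angle(cross / thru))  # cross leads thru by 90 degrees
```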

mmi2x2_ideal

mmi2x2_ideal(*, wl: FloatArrayLike = WL_C) -> SDict

Ideal 2x2 multimode interference (MMI) coupler model.

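Ports: inputs o1 (in0) and o2 (in1); outputs o3 (out1) and o4 (out0).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |

Returns:

| Type | Description |
|------|-------------|
| SDict | S-matrix dictionary representing the ideal MMI coupler behavior. |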

Examples:

Ideal 2x2 MMI:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi2x2_ideal(wl=wl)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi2x2_ideal(*, wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal 2x2 multimode interference (MMI) coupler model.

    ```{svgbob}
    in1     +-------------+     out1
     o2 *---|             |---* o3
            |             |
     o1 *---|             |---* o4
    in0     +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.

    Returns:
        S-matrix dictionary representing the ideal MMI coupler behavior.

    Examples:
        Ideal 2x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi2x2_ideal(wl=wl)
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    return coupler_ideal(wl=wl, coupling=0.5)
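
Two 50/50 couplers around a phase section form a Mach-Zehnder interferometer, which is a common use of `mmi2x2_ideal`. The sketch below assumes that `coupler_ideal(coupling=0.5)` has the transfer matrix `[[1, 1j], [1j, 1]] / sqrt(2)`, with a real thru term and a +90° cross term. This is the same convention that `model_4port` and `mmi2x2` use in this section, but `coupler_ideal`'s source is not shown here. With a phase `phi` on one arm, the bar power follows sin²(phi/2) and the cross power follows cos²(phi/2).

```python
import numpy as np

# Assumed 2x2 transfer matrix of an ideal 50/50 coupler (rows: outputs, cols: inputs)
M = np.array([[1, 1j], [1j, 1]]) / np.sqrt(2)


def mzi(phi: float) -> np.ndarray:
    """Coupler -> phase phi on arm 0 -> coupler."""
    arms = np.diag([np.exp(1j * phi), 1.0])
    return M @ arms @ M


phis = np.linspace(0, 2 * np.pi, 9)
bar = np.array([abs(mzi(p)[0, 0]) ** 2 for p in phis])
cross = np.array([abs(mzi(p)[1, 0]) ** 2 for p in phis])
```

At `phi = 0`, all power leaves through the cross port. This is why two cascaded 50/50 couplers with equal arms behave like a full crossover.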

model_2port

model_2port(p1: Name, p2: Name) -> SDictModel

Generate a general 2-port model with 100% transmission.

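Ports: p1 and p2, joined by a single straight connection.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| p1 | Name | Name of the first port (typically in0 or o1). | required |
| p2 | Name | Name of the second port (typically out0 or o2). | required |

Returns:

| Type | Description |
|------|-------------|
| SDictModel | A 2-port model: a function of wl returning an SDict with unit transmission between p1 and p2. |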

Source code in src/sax/models/factories.py
@validate_call
def model_2port(p1: sax.Name, p2: sax.Name) -> sax.SDictModel:
    """Generate a general 2-port model with 100% transmission.

    ```{svgbob}
     p1 *--------* p2
    ```

    Args:
        p1: Name of the first port (typically in0 or o1).
        p2: Name of the second port (typically out0 or o2).

    Returns:
        A 2-port model
    """

    @jax.jit
    @validate_call
    def model_2port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        return sax.reciprocal({(p1, p2): jnp.ones_like(wl)})

    return model_2port
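
`model_2port` is a factory. It captures the port names in a closure and returns a model that depends only on `wl`. The dependency-free sketch below mirrors that pattern with NumPy. Here `reciprocal` is a hypothetical stand-in that copies each `(a, b)` entry to `(b, a)`; it is assumed, not verified against the SAX source, to match what `sax.reciprocal` does in this simple case.

```python
import numpy as np


def reciprocal(sdict):
    """Hypothetical stand-in for sax.reciprocal: add (b, a) for every (a, b)."""
    out = dict(sdict)
    for (a, b), value in sdict.items():
        out[(b, a)] = value
    return out


def make_2port(p1, p2):
    """Factory: the returned model closes over the port names p1 and p2."""

    def model(wl=1.5):
        wl = np.asarray(wl)
        return reciprocal({(p1, p2): np.ones_like(wl)})  # 100% transmission

    return model


wg = make_2port("in0", "out0")
s = wg(wl=np.array([1.50, 1.55, 1.60]))
```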

model_3port

model_3port(p1: Name, p2: Name, p3: Name) -> SDictModel

Generate a general 3-port model (1x2 splitter).

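Ports: p1 on the input side; p2 and p3 on the output side.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| p1 | Name | Name of the input port (typically o1 or in0). | required |
| p2 | Name | Name of the first output port (typically o2 or out1). | required |
| p3 | Name | Name of the second output port (typically o3 or out0). | required |

Returns:

| Type | Description |
|------|-------------|
| SDictModel | A 3-port model: a function of wl returning an SDict that splits p1 equally (amplitude 1/√2 each) into p2 and p3. |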

Source code in src/sax/models/factories.py
@validate_call
def model_3port(p1: sax.Name, p2: sax.Name, p3: sax.Name) -> sax.SDictModel:
    """Generate a general 3-port model (1x2 splitter).

    ```{svgbob}
            +-----+--* p2
     p1 *---|     |
            +-----+--* p3
    ```

    Args:
        p1: Name of the input port (typically o1 or in0).
        p2: Name of the first output port (typically o2 or out1).
        p3: Name of the second output port (typically o3 or out0).

    Returns:
        A 3-port model
    """

    @jax.jit
    @validate_call
    def model_3port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        thru = jnp.ones_like(wl) / jnp.sqrt(2)
        return sax.reciprocal(
            {
                (p1, p2): thru,
                (p1, p3): thru,
            }
        )

    return model_3port
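
Because `model_3port` is reciprocal, the same `1/sqrt(2)` amplitudes also describe the device run backwards as a combiner. The hand-written NumPy sketch below (not SAX) superposes two equal-power fields entering p2 and p3. When they are in phase, they recombine fully into p1. When they are in antiphase, nothing reaches p1; in a physical Y-junction, that power is typically radiated.

```python
import numpy as np

thru = 1 / np.sqrt(2)  # amplitude of both (p1, p2) and (p1, p3) in model_3port


def field_at_p1(a2: complex, a3: complex) -> complex:
    """Field leaving p1 for input fields a2 (at p2) and a3 (at p3)."""
    return thru * a2 + thru * a3  # S[p1, p2] == S[p1, p3] == thru by reciprocity


a = 1 / np.sqrt(2)  # each input carries half of the total power
p_in_phase = abs(field_at_p1(a, a)) ** 2
p_anti_phase = abs(field_at_p1(a, -a)) ** 2
```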

model_4port

model_4port(p1: Name, p2: Name, p3: Name, p4: Name) -> SDictModel

Generate a general 4-port model (2x2 coupler).

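Ports: p1 and p2 on the input side; p3 and p4 on the output side (p2 is opposite p3, p1 is opposite p4).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| p1 | Name | Name of the first input port (typically o1 or in0). | required |
| p2 | Name | Name of the second input port (typically o2 or in1). | required |
| p3 | Name | Name of the first output port (typically o3 or out1). | required |
| p4 | Name | Name of the second output port (typically o4 or out0). | required |

Returns:

| Type | Description |
|------|-------------|
| SDictModel | A 4-port model: a function of wl returning an SDict for a 50/50 coupler (thru p1→p4 and p2→p3; cross with a 90° phase p1→p3 and p2→p4). |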

Source code in src/sax/models/factories.py
@validate_call
def model_4port(
    p1: sax.Name, p2: sax.Name, p3: sax.Name, p4: sax.Name
) -> sax.SDictModel:
    """Generate a general 4-port model (2x2 coupler).

    ```{svgbob}
    p2 *---+-----+--* p3
           |     |
    p1 *---+-----+--* p4
    ```

    Args:
        p1: Name of the first input port (typically o1 or in0).
        p2: Name of the second input port (typically o2 or in1).
        p3: Name of the first output port (typically o3 or out1).
        p4: Name of the second output port (typically o4 or out0).

    Returns:
        A 4-port model
    """

    @jax.jit
    @validate_call
    def model_4port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        thru = jnp.ones_like(wl) / jnp.sqrt(2)
        cross = 1j * thru
        return sax.reciprocal(
            {
                (p1, p4): thru,
                (p2, p3): thru,
                (p1, p3): cross,
                (p2, p4): cross,
            }
        )

    return model_4port
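
Writing the `model_4port` dictionary into a dense 4×4 matrix is a quick way to check that a 50/50 coupler with a 90° cross phase is lossless. The NumPy sketch below fills the matrix symmetrically to reflect reciprocity, using the port order p1, p2, p3, p4, and then tests S^H S = I.

```python
import numpy as np

thru = 1 / np.sqrt(2)
cross = 1j * thru

# Entries from model_4port; indices follow the port order p1, p2, p3, p4
entries = {(0, 3): thru, (1, 2): thru, (0, 2): cross, (1, 3): cross}

S = np.zeros((4, 4), dtype=complex)
for (i, j), value in entries.items():
    S[i, j] = S[j, i] = value  # reciprocal: S is symmetric

is_unitary = np.allclose(S.conj().T @ S, np.eye(4))
split_from_p1 = np.abs(S[:, 0]) ** 2  # power reaching each port from p1
```

The 90° phase on the cross terms is what makes the matrix unitary. With real, positive amplitudes everywhere, the columns would not be orthogonal.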

passthru

passthru(num_links: int, *, reciprocal: bool = True) -> SCooModel

Copy 100% of the power at each input port to its corresponding output port.

Ports: num_links inputs (o1/in0, o2/in1, …) on one side, each linked to a single output on the opposite side.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| num_links | int | Number of independent pass-through links (input-output pairs). This creates a device with num_links inputs and num_links outputs. | required |
| reciprocal | bool | If True, the device exhibits reciprocal behavior where transmission is identical in both directions. This is typical for passive devices like fibers and waveguides. | True |

Returns:

| Type | Description |
|------|-------------|
| SCooModel | A passthru model |

Examples:

Create an 8×8 optical switch (straight-through state):

import sax

switch_thru = sax.models.passthru(8, reciprocal=True)
Si, Sj, Sx, port_map = switch_thru(wl=1.55)
# Each input passes straight through to its corresponding output
Source code in src/sax/models/factories.py
@validate_call
def passthru(
    num_links: int,
    *,
    reciprocal: bool = True,
) -> sax.SCooModel:
    """Copy 100% of the power at each input port to its corresponding output port.

    ```{svgbob}
    o_ni -1 *---+-----+--* o_ni
                |     |
                  .
                  .
                  .
                |     |
         o2 *---+-----+--* o_N -2
        in1     |     |
                |     |
         o1 *---+-----+--* o_N -1
        in0
    ```

    Args:
        num_links: Number of independent pass-through links (input-output pairs).
            This creates a device with num_links inputs and num_links outputs.
        reciprocal: If True, the device exhibits reciprocal behavior where
            transmission is identical in both directions. This is typical for
            passive devices like fibers and waveguides. Defaults to True.

    Returns:
        A passthru model

    Examples:
        Create an 8×8 optical switch (straight-through state):

        ```python
        import sax

        switch_thru = sax.models.passthru(8, reciprocal=True)
        Si, Sj, Sx, port_map = switch_thru(wl=1.55)
        # Each input passes straight through to its corresponding output
        ```
    """
    passthru = unitary(
        num_links,
        num_links,
        reciprocal=reciprocal,
        diagonal=True,
    )
    passthru.__name__ = f"passthru_{num_links}_{num_links}"
    passthru.__qualname__ = f"passthru_{num_links}_{num_links}"
    return jax.jit(passthru)
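
The SCOO tuple returned by `passthru` stores only the nonzero entries. It consists of row indices `Si`, column indices `Sj`, values `Sx` (with any wavelength dimensions in front), and a port map. The helper below is a NumPy sketch, not a SAX function, that scatters such a tuple into a dense matrix. The example triple is written by hand to mirror what `passthru(2)` is expected to produce: inputs 0 and 1 linked to outputs 2 and 3 in both directions. It is not captured output.

```python
import numpy as np


def scoo_to_dense(Si, Sj, Sx, num_ports):
    """Scatter (Si, Sj, Sx) into a dense (..., num_ports, num_ports) array."""
    Sx = np.asarray(Sx)
    dense = np.zeros((*Sx.shape[:-1], num_ports, num_ports), dtype=Sx.dtype)
    dense[..., Si, Sj] = Sx  # leading (e.g. wavelength) axes are kept
    return dense


# Hand-written SCOO for a 2-link pass-through: 0 <-> 2 and 1 <-> 3
Si = np.array([0, 1, 2, 3])
Sj = np.array([2, 3, 0, 1])
Sx = np.ones((3, 4))  # three wavelengths, four nonzero entries each

S = scoo_to_dense(Si, Sj, Sx, num_ports=4)
```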

phase_shifter

phase_shifter(
    wl: FloatArrayLike = WL_C,
    neff: FloatArrayLike = 2.34,
    voltage: FloatArrayLike = 0,
    length: FloatArrayLike = 10,
    loss: FloatArrayLike = 0.0,
) -> SDict

Simple voltage-controlled phase shifter model.

Ports: o1 (in0) → o2 (out0).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |
| neff | FloatArrayLike | The effective index of the unperturbed waveguide mode. | 2.34 |
| voltage | FloatArrayLike | The applied voltage in volts. The phase shift is assumed to be linearly proportional to voltage with a coefficient of π rad/V. Positive voltage increases the phase. | 0 |
| length | FloatArrayLike | The length of the phase shifter in micrometers. | 10 |
| loss | FloatArrayLike | Propagation loss of the active region in dB per micrometer; the total loss is loss * length dB. | 0.0 |

Returns:

| Type | Description |
|------|-------------|
| SDict | S-matrix dictionary containing the complex-valued transmission coefficient. |

Examples:

Phase shifter:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.phase_shifter(wl=wl, loss=3.0)
thru = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/straight.py
@jax.jit
@validate_call
def phase_shifter(
    wl: sax.FloatArrayLike = sax.WL_C,
    neff: sax.FloatArrayLike = 2.34,
    voltage: sax.FloatArrayLike = 0,
    length: sax.FloatArrayLike = 10,
    loss: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    """Simple voltage-controlled phase shifter model.

    ```{svgbob}
    in0             out0
     o1 =========== o2
    ```

    Args:
        wl: The wavelength in micrometers.
        neff: The effective index of the unperturbed waveguide mode.
        voltage: The applied voltage in volts. The phase shift is assumed to be
            linearly proportional to voltage with a coefficient of π rad/V.
            Positive voltage increases the phase. Defaults to 0 V.
        length: The length of the phase shifter in micrometers.
        loss: Propagation loss of the active region in dB per micrometer; the
            total loss is `loss * length` dB.

    Returns:
        S-matrix dictionary containing the complex-valued transmission coefficient.

    Examples:
        Phase shifter:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.phase_shifter(wl=wl, loss=3.0)
        thru = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    deltaphi = voltage * jnp.pi
    phase = 2 * jnp.pi * neff * length / wl + deltaphi
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    transmission = amplitude * jnp.exp(1j * phase)
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.o1, p.o2): transmission,
        }
    )
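
The phase and amplitude formulas in `phase_shifter` can be checked outside JAX. The sketch below copies them into NumPy as an illustrative re-implementation, not the SAX function itself. It confirms two consequences of the source. First, 1 V adds π rad of phase. Second, because the amplitude is `10 ** (-loss * length / 20)`, the total loss in dB is `loss * length`.

```python
import numpy as np


def phase_shifter_t(wl, neff=2.34, voltage=0.0, length=10.0, loss=0.0):
    """NumPy copy of the transmission formula in phase_shifter."""
    deltaphi = voltage * np.pi  # pi rad per volt
    phase = 2 * np.pi * neff * length / wl + deltaphi
    amplitude = 10 ** (-loss * length / 20)  # loss is scaled by length
    return amplitude * np.exp(1j * phase)


t_0V = phase_shifter_t(1.55, voltage=0.0)
t_1V = phase_shifter_t(1.55, voltage=1.0)
extra_phase = np.angle(t_1V / t_0V)  # +/- pi: one volt flips the sign of the field

t_lossy = phase_shifter_t(1.55, length=10.0, loss=0.3)
total_loss_dB = -10 * np.log10(abs(t_lossy) ** 2)  # 0.3 * 10 = 3 dB
```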

splitter_ideal

splitter_ideal(*, wl: FloatArrayLike = WL_C, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 1x2 power splitter model.

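Ports: o1 (in0) on the input side; o2 and o3 on the output side.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| wl | FloatArrayLike | Wavelength in micrometers. | WL_C |
| coupling | FloatArrayLike | Fraction of the input power coupled to out1, between 0 and 1; the remaining 1 - coupling goes to out0. | 0.5 |

Returns:

| Type | Description |
|------|-------------|
| SDict | S-matrix dictionary containing the complex-valued cross/thru coefficients. |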

Examples:

Ideal 1x2 splitter:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.splitter_ideal(wl=wl, coupling=0.3)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/splitters.py
@jax.jit
@validate_call
def splitter_ideal(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    coupling: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Ideal 1x2 power splitter model.

    ```{svgbob}
                    .--------- out1
                   /           o2
                  /   .-------
    in0 ---------'   /
     o1             (
        ---------.   \
                  \   '------- out0
                   \           o3
                    '---------
    ```

    Args:
        wl: Wavelength in micrometers.
        coupling: Fraction of the input power coupled to out1, between 0 and 1;
            the remaining 1 - coupling goes to out0.

    Returns:
        S-matrix dictionary containing the complex-valued cross/thru coefficients.

    Examples:
        Ideal 1x2 splitter:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.splitter_ideal(wl=wl, coupling=0.3)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    kappa = jnp.asarray(coupling**0.5) * one
    tau = jnp.asarray((1 - coupling) ** 0.5) * one

    p = sax.PortNamer(num_inputs=1, num_outputs=2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): kappa,
        },
    )
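
With a non-default `coupling`, `splitter_ideal` sends `1 - coupling` of the power to out0 and `coupling` to out1, conserving the input power. The short NumPy sketch below reproduces the amplitudes from the source and converts the split into a dB imbalance. The `coupling=0.3` value from the example above is used as input.

```python
import numpy as np


def splitter_powers(coupling: float) -> tuple[float, float]:
    """Output powers (out0, out1) of splitter_ideal for a given coupling."""
    kappa = np.sqrt(coupling)  # amplitude to out1
    tau = np.sqrt(1 - coupling)  # amplitude to out0
    return float(tau**2), float(kappa**2)


p_out0, p_out1 = splitter_powers(0.3)
imbalance_dB = 10 * np.log10(p_out0 / p_out1)  # about 3.68 dB
```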

unitary

unitary(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> SCooModel

Generate a unitary N×M optical device model.

Ports: num_inputs inputs (o1/in0, o2/in1, …) on one side and num_outputs outputs on the other.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| num_inputs | int | Number of input ports for the device. | required |
| num_outputs | int | Number of output ports for the device. | required |
| reciprocal | bool | If True, the device exhibits reciprocal behavior (S = S^T). This is typical for passive optical devices. | True |
| diagonal | bool | If True, creates a diagonal coupling matrix (each input couples to only one output). If False, creates full coupling between all input-output pairs. | False |

Returns:

| Type | Description |
|------|-------------|
| SCooModel | A unitary model |

Source code in src/sax/models/factories.py
@validate_call
def unitary(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> sax.SCooModel:
    """Generate a unitary N×M optical device model.

    ```{svgbob}
    o_ni -1 *---+-----+--* o_ni
                |     |
                  .
                  .
                  .
                |     |
         o2 *---+-----+--* o_N -2
        in1     |     |
                |     |
         o1 *---+-----+--* o_N -1
        in0
    ```

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        reciprocal: If True, the device exhibits reciprocal behavior (S = S^T).
            This is typical for passive optical devices. Defaults to True.
        diagonal: If True, creates a diagonal coupling matrix (each input
            couples to only one output). If False, creates full coupling
            between all input-output pairs. Defaults to False.

    Returns:
        A unitary model
    """
    # let's create the squared S-matrix:
    N = max(num_inputs, num_outputs)
    S = jnp.zeros((2 * N, 2 * N), dtype=float)

    if not diagonal:
        S = S.at[:N, N:].set(1)
    else:
        r = jnp.arange(
            N,
            dtype=int,
        )  # reciprocal only works if num_inputs == num_outputs!
        S = S.at[r, N + r].set(1)

    if reciprocal:
        if not diagonal:
            S = S.at[N:, :N].set(1)
        else:
            r = jnp.arange(
                N,
                dtype=int,
            )  # reciprocal only works if num_inputs == num_outputs!
            S = S.at[N + r, r].set(1)

    # Now we need to normalize the squared S-matrix
    U, s, V = jnp.linalg.svd(S, full_matrices=False)
    S = jnp.sqrt(U @ jnp.diag(jnp.where(s > sax.EPS, 1, 0)) @ V)

    # Now create subset of this matrix we're interested in:
    r = jnp.concatenate(
        [jnp.arange(num_inputs, dtype=int), N + jnp.arange(num_outputs, dtype=int)],
        0,
    )
    S = S[r, :][:, r]

    # let's convert it to SCOO format:
    Si, Sj = jnp.where(S > sax.EPS)
    Sx = S[Si, Sj]

    # the last missing piece is a port map:
    p = sax.PortNamer(num_inputs, num_outputs)
    pm = {
        **{p[i]: i for i in range(num_inputs)},
        **{p[i + num_inputs]: i + num_inputs for i in range(num_outputs)},
    }

    @validate_call
    def func(*, wl: sax.FloatArrayLike = 1.5) -> sax.SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
    func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
    return jax.jit(func)
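
The normalization step in `unitary` works in three stages. It builds a power matrix, snaps its nonzero singular values to 1, and takes an element-wise square root to obtain amplitudes. The NumPy sketch below reproduces only the `diagonal=False, reciprocal=True` branch, without wavelength broadcasting or SCOO conversion. The extra thresholding before `np.sqrt` is an addition here; it avoids square roots of tiny negative round-off. For a 1×2 device, the sketch yields the same 1/sqrt(2) amplitudes as `model_3port`.

```python
import numpy as np


def unitary_amplitudes(num_inputs: int, num_outputs: int, eps: float = 1e-12):
    """Sketch of unitary(..., reciprocal=True, diagonal=False) in NumPy."""
    N = max(num_inputs, num_outputs)
    P = np.zeros((2 * N, 2 * N))
    P[:N, N:] = 1.0  # every input couples to every output (power)
    P[N:, :N] = 1.0  # reciprocal: mirror into the lower-left block

    U, s, Vh = np.linalg.svd(P, full_matrices=False)
    P = U @ np.diag(np.where(s > eps, 1.0, 0.0)) @ Vh  # nonzero singular values -> 1
    P = np.where(P > eps, P, 0.0)  # added here: drop round-off before the root
    S = np.sqrt(P)  # power -> amplitude

    r = np.concatenate([np.arange(num_inputs), N + np.arange(num_outputs)])
    return S[np.ix_(r, r)]  # keep only the ports that exist


S12 = unitary_amplitudes(1, 2)
S22 = unitary_amplitudes(2, 2)
```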