
models

SAX Models.

Modules:

| Name | Description |
| --- | --- |
| bends | SAX Bend Models. |
| couplers | SAX Coupler Models. |
| crossings | SAX Crossing Models. |
| factories | SAX Default Models. |
| isolators | SAX Isolator and Circulator Models. |
| mmis | SAX MMI Models. |
| probes | SAX Probe Models. |
| reflectors | SAX Reflector Models. |
| rf | SAX generic RF models. |
| splitters | SAX Default Models. |
| straight | SAX Default Models. |
| terminators | SAX Terminator Models. |

Functions:

| Name | Description |
| --- | --- |
| admittance | Generalized two-port admittance element. |
| attenuator | Simple optical attenuator model. |
| bend | Simple waveguide bend model. |
| capacitor | Ideal two-port capacitor model. |
| circulator | Optical circulator model (non-reciprocal 3-port device). |
| copier | Copy 100% of the power at the input ports to all output ports. |
| coplanar_waveguide | S-parameter model for a straight coplanar waveguide. |
| coupler | Dispersive directional coupler model. |
| coupler_ideal | Ideal 2x2 directional coupler model. |
| cpw_epsilon_eff | Effective permittivity of a CPW on a finite-height substrate. |
| cpw_thickness_correction | Apply conductor thickness correction to CPW ε_eff and Z₀. |
| cpw_z0 | Characteristic impedance of a CPW. |
| crossing_ideal | Ideal waveguide crossing model. |
| electrical_open | Electrical open connection SAX model. |
| electrical_short | Electrical short connection SAX model. |
| ellipk_ratio | Ratio of complete elliptic integrals of the first kind K(m) / K(1-m). |
| gamma_0_load | Connection with given reflection coefficient. |
| grating_coupler | Grating coupler model for fiber-chip coupling. |
| ideal_probe | Ideal 4-port measurement probe with 100% transmission and 100% tap coupling. |
| impedance | Generalized two-port impedance element. |
| inductor | Ideal two-port inductor model. |
| isolator | Optical isolator model (non-reciprocal). |
| lc_shunt_component | SAX component for a 1-port shunted LC resonator. |
| microstrip | S-parameter model for a straight microstrip transmission line. |
| microstrip_epsilon_eff | Effective permittivity of a microstrip line. |
| microstrip_thickness_correction | Conductor thickness correction for a microstrip line. |
| microstrip_z0 | Characteristic impedance of a microstrip line. |
| mirror | Ideal mirror model (reflector with default 100% reflection). |
| mmi1x2 | Realistic 1x2 MMI splitter model with dispersion and loss. |
| mmi1x2_ideal | Ideal 1x2 multimode interference (MMI) splitter model. |
| mmi2x2 | Realistic 2x2 MMI coupler model with dispersion and asymmetry. |
| mmi2x2_ideal | Ideal 2x2 multimode interference (MMI) coupler model. |
| model_2port | Generate a general 2-port model with 100% transmission. |
| model_3port | Generate a general 3-port model (1x2 splitter). |
| model_4port | Generate a general 4-port model (2x2 coupler). |
| passthru | Copy 100% of the power at each input port to its corresponding output port. |
| phase_shifter | Simple voltage-controlled phase shifter model. |
| propagation_constant | Complex propagation constant of a quasi-TEM transmission line. |
| reflector | Partial reflector / mirror model. |
| resistor | Ideal two-port resistor model. |
| splitter_ideal | Ideal 1x2 power splitter model. |
| tee | Ideal three-port RF power divider/combiner (T-junction). |
| terminator | Optical terminator / absorber model. |
| transmission_line_s_params | S-parameters of a uniform transmission line (ABCD→S conversion). |
| unitary | Generate a unitary N×M optical device model. |

admittance

admittance(*, f: FloatArrayLike = DEFAULT_FREQUENCY, y: ComplexLike = 1 / 50) -> SDict

Generalized two-port admittance element.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f | FloatArrayLike | Frequency in Hz | DEFAULT_FREQUENCY |
| y | ComplexLike | Admittance in siemens | 1 / 50 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-dictionary representing the admittance element |

References

Pozar

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax
import jaxellip  # pyright: ignore[reportMissingImports]
import scipy.constants


sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
s = sax.models.rf.admittance(f=f, y=1 / 75)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Magnitude")
plt.legend()
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def admittance(
    *, f: sax.FloatArrayLike = DEFAULT_FREQUENCY, y: sax.ComplexLike = 1 / 50
) -> sax.SDict:
    r"""Generalized two-port admittance element.

    Args:
        f: Frequency in Hz
        y: Admittance in siemens

    Returns:
        S-dictionary representing the admittance element

    References:
        Pozar

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax
        import jaxellip  # pyright: ignore[reportMissingImports]
        import scipy.constants


        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        s = sax.models.rf.admittance(f=f, y=1 / 75)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Magnitude")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(f))
    sdict = {
        ("o1", "o1"): 1 / (1 + y) * one,
        ("o1", "o2"): y / (1 + y) * one,
        ("o2", "o2"): 1 / (1 + y) * one,
    }
    return sax.reciprocal(sdict)
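
In the source above, `f` only sets the output shape, so every entry is constant in frequency. A minimal plain-Python sketch of the same arithmetic, useful for checking values by hand, could look like this. The helper name `admittance_entries` is hypothetical and not part of sax; it uses `y` exactly as the listing does.

```python
def admittance_entries(y: complex) -> dict:
    """Mirror the entries built in the listing above for a single value of y."""
    s11 = 1 / (1 + y)  # reflection term, used for both ("o1", "o1") and ("o2", "o2")
    s12 = y / (1 + y)  # transmission term, mirrored to ("o2", "o1") by reciprocity
    return {
        ("o1", "o1"): s11,
        ("o1", "o2"): s12,
        ("o2", "o1"): s12,
        ("o2", "o2"): s11,
    }


s = admittance_entries(1 / 75)
print(abs(s[("o1", "o1")]), abs(s[("o1", "o2")]))
```

With `y = 1 / 75` the two magnitudes are 75/76 and 1/76, which is why the curves in the example plot are flat.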

attenuator

attenuator(*, wl: FloatArrayLike = WL_C, loss: FloatArrayLike = 0.0) -> SDict

Simple optical attenuator model.

Port diagram: o1 (in0) connected straight through to o2 (out0).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | The wavelength in micrometers. | WL_C |
| loss | FloatArrayLike | Attenuation in decibels (dB). | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-matrix dictionary containing the complex transmission coefficient. |

Examples:

Attenuator:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.attenuator(wl=wl, loss=3.0)
thru = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/straight.py
@jax.jit
@validate_call
def attenuator(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    loss: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    """Simple optical attenuator model.

    ```{svgbob}
    in0             out0
     o1 =========== o2
    ```

    Args:
        wl: The wavelength in micrometers.
        loss: Attenuation in decibels (dB).

    Returns:
        S-matrix dictionary containing the complex transmission coefficient.

    Examples:
        Attenuator:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.attenuator(wl=wl, loss=3.0)
        thru = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    transmission = jnp.asarray(10 ** (-loss / 20), dtype=complex) * one
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.out0): transmission,
        }
    )
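
The listing above converts a loss in dB into an amplitude with `10 ** (-loss / 20)`; squaring the magnitude gives the power transmission `10 ** (-loss / 10)`. A small standalone sketch of that conversion follows (the function name `attenuator_transmission` is hypothetical, not a sax API):

```python
def attenuator_transmission(loss_db: float) -> complex:
    """Amplitude transmission for a loss given in dB, as in the listing above."""
    return complex(10 ** (-loss_db / 20))


t = attenuator_transmission(3.0)
power = abs(t) ** 2  # same value as 10 ** (-3.0 / 10)
print(f"amplitude={abs(t):.4f} power={power:.4f}")
```

A 3 dB loss therefore corresponds to a power transmission of about 0.501, the level shown in the example plot.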

bend

bend(
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    neff: FloatArrayLike = 2.34,
    ng: FloatArrayLike = 3.4,
    length: FloatArrayLike = 10.0,
    loss_dB_cm: FloatArrayLike = 0.1,
) -> SDict

Simple waveguide bend model.

Port diagram: o1 (in0) at the lower left, bending up to o2 (out0) at the upper right.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. Defaults to 1.55 μm. | WL_C |
| wl0 | FloatArrayLike | Reference wavelength in micrometers used for dispersion calculation. This is typically the design wavelength where neff is specified. Defaults to 1.55 μm. | WL_C |
| neff | FloatArrayLike | Effective refractive index at the reference wavelength. This value represents the fundamental mode effective index and determines the phase velocity. Defaults to 2.34 (typical for silicon). | 2.34 |
| ng | FloatArrayLike | Group refractive index at the reference wavelength. Used to calculate chromatic dispersion: ng = neff - lambda * d(neff)/d(lambda). Typically ng > neff for normal dispersion. Defaults to 3.4. | 3.4 |
| length | FloatArrayLike | Physical length of the waveguide in micrometers. Determines both the total phase accumulation and loss. Defaults to 10.0 μm. | 10.0 |
| loss_dB_cm | FloatArrayLike | Propagation loss in dB/cm. Includes material absorption, scattering losses, and other loss mechanisms. Set to 0.0 for lossless modeling. Defaults to 0.1 dB/cm. | 0.1 |

Returns:

| Type | Description |
| --- | --- |
| SDict | The bend S-matrix. |

Examples:

Basic bend simulation:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax
sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.bend(wl=wl, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
thru = np.abs(s[("o1", "o2")]) ** 2

plt.plot(wl, thru)
plt.xlabel("Wavelength (μm)")
plt.ylabel("thru")
Note

This model treats the bend as an equivalent straight waveguide and does not account for:

  • Mode coupling between bend eigenmodes
  • Radiation losses due to bending
  • Polarization effects in bends
  • Non-linear dispersion effects

For more accurate bend modeling, consider using dedicated bend models that account for these physical effects.

Source code in src/sax/models/bends.py
@jax.jit
@validate_call
def bend(
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    neff: sax.FloatArrayLike = 2.34,
    ng: sax.FloatArrayLike = 3.4,
    length: sax.FloatArrayLike = 10.0,
    loss_dB_cm: sax.FloatArrayLike = 0.1,
) -> sax.SDict:
    """Simple waveguide bend model.

    ```{svgbob}
              out0
              o2

             /
         __.'
    o1
    in0
    ```

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Reference wavelength in micrometers used for dispersion calculation.
            This is typically the design wavelength where neff is specified.
            Defaults to 1.55 μm.
        neff: Effective refractive index at the reference wavelength. This value
            represents the fundamental mode effective index and determines the
            phase velocity. Defaults to 2.34 (typical for silicon).
        ng: Group refractive index at the reference wavelength. Used to calculate
            chromatic dispersion: ng = neff - lambda * d(neff)/d(lambda).
            Typically ng > neff for normal dispersion. Defaults to 3.4.
        length: Physical length of the waveguide in micrometers. Determines both
            the total phase accumulation and loss. Defaults to 10.0 μm.
        loss_dB_cm: Propagation loss in dB/cm. Includes material absorption,
            scattering losses, and other loss mechanisms. Set to 0.0 for
            lossless modeling. Defaults to 0.1 dB/cm.

    Returns:
        The bend s-matrix

    Examples:
        Basic bend simulation:

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax
    sax.set_port_naming_strategy("optical")

    wl = sax.wl_c()
    s = sax.models.bend(wl=wl, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
    thru = np.abs(s[("o1", "o2")]) ** 2

    plt.plot(wl, thru)
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("thru")
    ```

    Note:
        This model treats the bend as an equivalent straight waveguide and does not
        account for:

        - Mode coupling between bend eigenmodes
        - Radiation losses due to bending
        - Polarization effects in bends
        - Non-linear dispersion effects

        For more accurate bend modeling, consider using dedicated bend models that
        account for these physical effects.
    """
    return straight(
        wl=wl,
        wl0=wl0,
        neff=neff,
        ng=ng,
        length=length,
        loss_dB_cm=loss_dB_cm,
    )
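
`bend` forwards all of its arguments to `straight`, whose source is not shown in this section. The sketch below only illustrates what the documented parameters imply: a first-order effective index derived from the stated relation ng = neff - λ·d(neff)/dλ, and a dB/cm loss applied to a length given in micrometers. It is an assumption for illustration, not the library's implementation, and `waveguide_transmission` is a hypothetical name.

```python
import cmath
import math


def waveguide_transmission(wl, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss_dB_cm=0.1):
    """First-order dispersive waveguide transmission (illustrative assumption)."""
    # From ng = neff - wl0 * d(neff)/d(wl), the slope at wl0 is (neff - ng) / wl0.
    dneff_dwl = (neff - ng) / wl0
    neff_wl = neff + (wl - wl0) * dneff_dwl
    phase = 2 * math.pi * neff_wl * length / wl
    # length is in micrometers and the loss in dB/cm: 1 cm = 1e4 micrometers.
    loss_db = loss_dB_cm * length * 1e-4
    amplitude = 10 ** (-loss_db / 20)
    return amplitude * cmath.exp(1j * phase)


t = waveguide_transmission(wl=1.55, length=50.0, loss_dB_cm=0.1, neff=2.35, ng=3.5)
print(abs(t) ** 2)
```

Under these assumptions a 50 μm section at 0.1 dB/cm loses only 0.0005 dB, so the power transmission stays very close to 1.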

capacitor

capacitor(
    *,
    f: FloatArrayLike = DEFAULT_FREQUENCY,
    capacitance: FloatLike = 1e-15,
    z0: ComplexLike = 50,
) -> SDict

Ideal two-port capacitor model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f | FloatArrayLike | Frequency in Hz | DEFAULT_FREQUENCY |
| capacitance | FloatLike | Capacitance in Farads | 1e-15 |
| z0 | ComplexLike | Reference impedance in Ω. | 50 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-dictionary representing the capacitor element |

References

Pozar

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax
import jaxellip  # pyright: ignore[reportMissingImports]
import scipy.constants


sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
s = sax.models.rf.capacitor(f=f, capacitance=1e-12, z0=50)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Magnitude")
plt.legend()
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def capacitor(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    capacitance: sax.FloatLike = 1e-15,
    z0: sax.ComplexLike = 50,
) -> sax.SDict:
    r"""Ideal two-port capacitor model.

    Args:
        f: Frequency in Hz
        capacitance: Capacitance in Farads
        z0: Reference impedance in Ω.

    Returns:
        S-dictionary representing the capacitor element

    References:
        Pozar

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax
        import jaxellip  # pyright: ignore[reportMissingImports]
        import scipy.constants


        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        s = sax.models.rf.capacitor(f=f, capacitance=1e-12, z0=50)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Magnitude")
        plt.legend()
        ```
    """
    angular_frequency = 2 * jnp.pi * jnp.asarray(f)
    capacitor_impedance = 1 / (1j * angular_frequency * capacitance)
    return impedance(f=f, z=capacitor_impedance, z0=z0)
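
The listing above reduces the capacitor to a frequency-dependent impedance Z = 1 / (jωC) and hands it to `impedance`. The following standalone sketch reproduces only that first step in plain Python, so the magnitudes can be checked by hand (`capacitor_impedance` as a function is a hypothetical helper, not a sax API):

```python
import math


def capacitor_impedance(f_hz: float, capacitance: float) -> complex:
    """Impedance of an ideal capacitor, Z = 1 / (j * omega * C)."""
    omega = 2 * math.pi * f_hz
    return 1 / (1j * omega * capacitance)


for f_hz in (1e9, 5e9, 10e9):
    z = capacitor_impedance(f_hz, 1e-12)
    print(f"{f_hz / 1e9:4.0f} GHz: |Z| = {abs(z):7.2f} ohm")
```

For the 1 pF value used in the example, |Z| falls from roughly 159 Ω at 1 GHz to roughly 16 Ω at 10 GHz, and the impedance is purely imaginary with a negative sign.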

circulator

circulator(
    *,
    wl: FloatArrayLike = WL_C,
    insertion_loss_dB: FloatArrayLike = 0.0,
    isolation_dB: FloatArrayLike = 40.0,
) -> SDict

Optical circulator model (non-reciprocal 3-port device).

Routes light in a circular fashion: port 1 -> port 2 -> port 3 -> port 1. Light traveling in the reverse direction is strongly attenuated.

Port diagram: o2 at the top, o1 at the lower left, o3 at the lower right.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Wavelength in micrometers. | WL_C |
| insertion_loss_dB | FloatArrayLike | Insertion loss in dB for the forward circulation path. Defaults to 0.0 dB. | 0.0 |
| isolation_dB | FloatArrayLike | Isolation in dB between non-adjacent ports (reverse direction). Defaults to 40.0 dB. | 40.0 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-matrix dictionary for the circulator. |

Examples:

3-port circulator:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.circulator(
    wl=wl,
    insertion_loss_dB=0.5,
    isolation_dB=30.0,
)
fwd_12 = np.abs(s[("o1", "o2")]) ** 2
fwd_23 = np.abs(s[("o2", "o3")]) ** 2
iso_21 = np.abs(s[("o2", "o1")]) ** 2
plt.figure()
plt.plot(wl, 10 * np.log10(fwd_12), label="o1→o2")
plt.plot(wl, 10 * np.log10(fwd_23), label="o2→o3")
plt.plot(wl, 10 * np.log10(iso_21 + 1e-10), label="o2→o1 (isolated)")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Transmission [dB]")
plt.legend()
Source code in src/sax/models/isolators.py
@jax.jit
@validate_call
def circulator(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    insertion_loss_dB: sax.FloatArrayLike = 0.0,
    isolation_dB: sax.FloatArrayLike = 40.0,
) -> sax.SDict:
    r"""Optical circulator model (non-reciprocal 3-port device).

    Routes light in a circular fashion: port 1 -> port 2 -> port 3 -> port 1.
    Light traveling in the reverse direction is strongly attenuated.

    ```{svgbob}
          o2
          *
         / \\
        /   \\
       / --> \\
      *-------*
     o1       o3
    ```

    Args:
        wl: Wavelength in micrometers.
        insertion_loss_dB: Insertion loss in dB for the forward circulation
            path. Defaults to 0.0 dB.
        isolation_dB: Isolation in dB between non-adjacent ports (reverse
            direction). Defaults to 40.0 dB.

    Returns:
        S-matrix dictionary for the circulator.

    Examples:
        3-port circulator:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.circulator(
            wl=wl,
            insertion_loss_dB=0.5,
            isolation_dB=30.0,
        )
        fwd_12 = np.abs(s[("o1", "o2")]) ** 2
        fwd_23 = np.abs(s[("o2", "o3")]) ** 2
        iso_21 = np.abs(s[("o2", "o1")]) ** 2
        plt.figure()
        plt.plot(wl, 10 * np.log10(fwd_12), label="o1→o2")
        plt.plot(wl, 10 * np.log10(fwd_23), label="o2→o3")
        plt.plot(wl, 10 * np.log10(iso_21 + 1e-10), label="o2→o1 (isolated)")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Transmission [dB]")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    forward = jnp.asarray(10 ** (-insertion_loss_dB / 20), dtype=complex) * one
    isolated = jnp.asarray(10 ** (-isolation_dB / 20), dtype=complex) * one
    return {
        # Forward circulation: o1 -> o2 -> o3 -> o1
        ("o1", "o2"): forward,
        ("o2", "o3"): forward,
        ("o3", "o1"): forward,
        # Reverse (isolated): o2 -> o1, o3 -> o2, o1 -> o3
        ("o2", "o1"): isolated,
        ("o3", "o2"): isolated,
        ("o1", "o3"): isolated,
    }
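
Unlike the reciprocal models on this page, the listing above returns its dictionary directly instead of passing it through `sax.reciprocal`, so the two directions between a pair of ports can differ. A scalar, plain-Python sketch of the same six entries is shown here (`circulator_entries` is a hypothetical helper name):

```python
def circulator_entries(insertion_loss_dB: float = 0.0, isolation_dB: float = 40.0) -> dict:
    """Scalar version of the six entries returned by the listing above."""
    forward = complex(10 ** (-insertion_loss_dB / 20))
    isolated = complex(10 ** (-isolation_dB / 20))
    return {
        # Forward circulation: o1 -> o2 -> o3 -> o1
        ("o1", "o2"): forward, ("o2", "o3"): forward, ("o3", "o1"): forward,
        # Reverse (isolated) direction
        ("o2", "o1"): isolated, ("o3", "o2"): isolated, ("o1", "o3"): isolated,
    }


s = circulator_entries(insertion_loss_dB=0.5, isolation_dB=30.0)
# Non-reciprocal: swapping the two port names changes the value.
print(abs(s[("o1", "o2")]) ** 2, abs(s[("o2", "o1")]) ** 2)
```

With 0.5 dB insertion loss and 30 dB isolation, the forward power transmission is about 0.89 and the reverse one is 0.001. Reflection entries such as `("o1", "o1")` are absent from the dictionary.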

copier

copier(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> SCooModel

Copy 100% of the power at the input ports to all output ports.

Port diagram: input ports (o1/in0, o2/in1, ...) on the left and output ports on the right.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_inputs | int | Number of input ports for the device. | required |
| num_outputs | int | Number of output ports for the device. | required |
| reciprocal | bool | If True, the device exhibits reciprocal behavior where forward and reverse transmissions are equal. Defaults to True. | True |
| diagonal | bool | If True, creates diagonal coupling (each input couples to only one output). If False, creates full coupling between all input-output pairs. Defaults to False. | False |

Returns:

| Type | Description |
| --- | --- |
| SCooModel | A copier model |

Examples:

A 1×4 optical amplifier/splitter:

import sax

amp_splitter = sax.models.copier(1, 4, reciprocal=False)
Si, Sj, Sx, port_map = amp_splitter(wl=1.55)
# Single input amplified and split to 4 outputs
Source code in src/sax/models/factories.py
@validate_call
def copier(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> sax.SCooModel:
    """Copy 100% of the power at the input ports to all output ports.

    ```{svgbob}
    o_ni -1 *---+-----+--* o_ni
                |     |
                  .
                  .
                  .
                |     |
         o2 *---+-----+--* o_N -2
        in1     |     |
                |     |
         o1 *---+-----+--* o_N -1
        in0
    ```

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        reciprocal: If True, the device exhibits reciprocal behavior where
            forward and reverse transmissions are equal. Defaults to True.
        diagonal: If True, creates diagonal coupling (each input couples to
            only one output). If False, creates full coupling between all
            input-output pairs. Defaults to False.

    Returns:
        A copier model

    Examples:
        A 1×4 optical amplifier/splitter:

        ```python
        import sax

        amp_splitter = sax.models.copier(1, 4, reciprocal=False)
        Si, Sj, Sx, port_map = amp_splitter(wl=1.55)
        # Single input amplified and split to 4 outputs
        ```

    """
    # let's create the squared S-matrix:
    S = jnp.zeros((num_inputs + num_outputs, num_inputs + num_outputs), dtype=float)

    if not diagonal:
        S = S.at[:num_inputs, num_inputs:].set(1)
    else:
        r = jnp.arange(
            num_inputs,
            dtype=int,
        )  # == range(num_outputs) # reciprocal only works if num_inputs == num_outputs!
        S = S.at[r, num_inputs + r].set(1)

    if reciprocal:
        if not diagonal:
            S = S.at[num_inputs:, :num_inputs].set(1)
        else:
            # reciprocal only works if num_inputs == num_outputs!
            r = jnp.arange(num_inputs, dtype=int)  # == range(num_outputs)
            S = S.at[num_inputs + r, r].set(1)

    S = S.T  # convert from (from, to) to standard (to, from) convention

    # let's convert it in SCOO format:
    Si, Sj = jnp.where(jnp.sqrt(sax.EPS) < S)
    Sx = S[Si, Sj]

    # the last missing piece is a port map:
    p = sax.PortNamer(num_inputs, num_outputs)
    pm = {
        **{p[i]: i for i in range(num_inputs)},
        **{p[i + num_inputs]: i + num_inputs for i in range(num_outputs)},
    }

    @validate_call
    def func(wl: sax.FloatArrayLike = 1.5) -> sax.SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
    func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
    return jax.jit(func)
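
The listing above builds a dense connectivity matrix, transposes it to the (to, from) convention, and keeps only the non-zero entries as COO index lists. The sketch below reproduces that bookkeeping with plain Python lists so the resulting indices are easy to inspect. `copier_coo` is a hypothetical helper, and it omits the port map and the wavelength broadcasting.

```python
def copier_coo(num_inputs, num_outputs, reciprocal=True, diagonal=False):
    """Return (Si, Sj, Sx) for the copier connectivity, in (to, from) order."""
    pairs = set()  # (from, to) index pairs that carry unit amplitude
    if diagonal:
        # As the source comments note, this branch assumes num_inputs == num_outputs.
        for r in range(num_inputs):
            pairs.add((r, num_inputs + r))
    else:
        for i in range(num_inputs):
            for o in range(num_outputs):
                pairs.add((i, num_inputs + o))
    if reciprocal:
        pairs |= {(to, frm) for frm, to in pairs}
    # Transpose to (to, from) and sort row-major, the order jnp.where reports.
    entries = sorted((to, frm) for frm, to in pairs)
    Si = [row for row, _ in entries]
    Sj = [col for _, col in entries]
    Sx = [1.0] * len(entries)
    return Si, Sj, Sx


print(copier_coo(1, 4, reciprocal=False))
```

For the 1×4 case every output index (1 to 4) receives the full unit amplitude from input index 0. Power is therefore not conserved, which is why the example above describes the device as an amplifier/splitter.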

coplanar_waveguide

coplanar_waveguide(
    f: FloatArrayLike = DEFAULT_FREQUENCY,
    length: FloatLike = 1000.0,
    width: FloatLike = 10.0,
    gap: FloatLike = 5.0,
    thickness: FloatLike = 0.0,
    substrate_thickness: FloatLike = 500.0,
    ep_r: FloatLike = 11.45,
    tand: FloatLike = 0.0,
) -> SDict

S-parameter model for a straight coplanar waveguide.

Computes S-parameters analytically using conformal-mapping CPW theory following Simons and the Qucs-S CPW model. Conductor thickness corrections use the first-order model of Gupta, Garg, Bahl, and Bhartia.

References

Simons, ch. 2; Gupta, Garg, Bahl & Bhartia; Qucs technical documentation, §12.4

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f | FloatArrayLike | Array of frequency points in Hz | DEFAULT_FREQUENCY |
| length | FloatLike | Physical length in µm | 1000.0 |
| width | FloatLike | Centre-conductor width in µm | 10.0 |
| gap | FloatLike | Gap between centre conductor and ground plane in µm | 5.0 |
| thickness | FloatLike | Conductor thickness in µm | 0.0 |
| substrate_thickness | FloatLike | Substrate height in µm | 500.0 |
| ep_r | FloatLike | Relative permittivity of the substrate | 11.45 |
| tand | FloatLike | Dielectric loss tangent | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| SDict | S-parameter dictionary |

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def coplanar_waveguide(
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    length: sax.FloatLike = 1000.0,
    width: sax.FloatLike = 10.0,
    gap: sax.FloatLike = 5.0,
    thickness: sax.FloatLike = 0.0,
    substrate_thickness: sax.FloatLike = 500.0,
    ep_r: sax.FloatLike = 11.45,
    tand: sax.FloatLike = 0.0,
) -> sax.SDict:
    r"""S-parameter model for a straight coplanar waveguide.

    Computes S-parameters analytically using conformal-mapping CPW theory
    following Simons and the Qucs-S CPW model.
    Conductor thickness corrections use the first-order model of
    Gupta, Garg, Bahl, and Bhartia.

    References:
       Simons, ch. 2;
       Gupta, Garg, Bahl & Bhartia;
       Qucs technical documentation, §12.4

    Args:
        f: Array of frequency points in Hz
        length: Physical length in µm
        width: Centre-conductor width in µm
        gap: Gap between centre conductor and ground plane in µm
        thickness: Conductor thickness in µm
        substrate_thickness: Substrate height in µm
        ep_r: Relative permittivity of the substrate
        tand: Dielectric loss tangent

    Returns:
        sax.SDict: S-parameters dictionary
    """
    f = jnp.asarray(f)
    f_flat = f.ravel()

    w_m = width * 1e-6
    s_m = gap * 1e-6
    t_m = thickness * 1e-6
    h_m = substrate_thickness * 1e-6
    length_m = jnp.asarray(length) * 1e-6

    ep_eff = cpw_epsilon_eff(w_m, s_m, h_m, ep_r)

    # Thickness is a scalar, so we use jnp.where to conditionally apply corrections.
    # Using jnp.where is safe:
    ep_eff_t, z0_val = cpw_thickness_correction(w_m, s_m, t_m, ep_eff)

    gamma = propagation_constant(f_flat, ep_eff_t, tand=tand, ep_r=ep_r)
    s11, s21 = transmission_line_s_params(gamma, z0_val, length_m)

    sdict: sax.SDict = {
        ("o1", "o1"): s11.reshape(f.shape),
        ("o1", "o2"): s21.reshape(f.shape),
        ("o2", "o2"): s11.reshape(f.shape),
    }
    return sax.reciprocal(sdict)
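
The helpers `propagation_constant` and `transmission_line_s_params` are called in the listing above, but their source is not part of this section. As a rough illustration of the ABCD→S step, here is the textbook conversion for a uniform line in plain Python. The lossless propagation constant, the 50 Ω reference impedance, the value of `ep_eff`, and the name `line_s_params` are all assumptions made for this sketch and may differ from the library's code.

```python
import cmath
import math

C0 = 299_792_458.0  # speed of light in vacuum, m/s


def line_s_params(gamma: complex, z0: complex, length_m: float, z_ref: complex = 50.0):
    """Textbook ABCD -> S conversion for a uniform transmission line."""
    a = d = cmath.cosh(gamma * length_m)
    b = z0 * cmath.sinh(gamma * length_m)
    c = cmath.sinh(gamma * length_m) / z0
    denom = a + b / z_ref + c * z_ref + d
    s11 = (a + b / z_ref - c * z_ref - d) / denom
    s21 = 2 / denom
    return s11, s21


f_hz, ep_eff = 5e9, 6.2   # illustrative values only
length_m = 1000.0 * 1e-6  # 1000 µm expressed in meters, as in the listing above
gamma = 1j * 2 * math.pi * f_hz * math.sqrt(ep_eff) / C0  # lossless quasi-TEM line
s11, s21 = line_s_params(gamma, z0=50.0, length_m=length_m)
print(abs(s11), abs(s21))
```

When the line impedance equals the reference impedance and the line is lossless, the reflection vanishes and the transmission has unit magnitude, which makes a convenient sanity check.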

coupler

coupler(
    *,
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    length: FloatArrayLike = 0.0,
    coupling0: FloatArrayLike = 0.2,
    dk1: FloatArrayLike = 1.2435,
    dk2: FloatArrayLike = 5.3022,
    dn: FloatArrayLike = 0.02,
    dn1: FloatArrayLike = 0.1169,
    dn2: FloatArrayLike = 0.4821,
) -> SDict

Dispersive directional coupler model.

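Port diagram: o1 (in0) and o2 (in1) on the left, o4 (out0) and o3 (out1) on the right. Each bend region contributes coupling0/2, and the straight section of the given length contributes the length-dependent coupling.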

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wl | FloatArrayLike | Operating wavelength in micrometers. Can be a scalar or array for multi-wavelength simulations. Defaults to 1.55 μm. | WL_C |
| wl0 | FloatArrayLike | Reference wavelength in micrometers for dispersion expansion. Typically the center design wavelength. Defaults to 1.55 μm. | WL_C |
| length | FloatArrayLike | Coupling length in micrometers. This is the length over which the two waveguides are in close proximity. Defaults to 0.0 μm. | 0.0 |
| coupling0 | FloatArrayLike | Base coupling coefficient at the reference wavelength from bend regions or other coupling mechanisms. Obtained from FDTD simulations or measurements. Defaults to 0.2. | 0.2 |
| dk1 | FloatArrayLike | First-order derivative of coupling coefficient with respect to wavelength (∂κ/∂λ). Units: μm⁻¹. Defaults to 1.2435. | 1.2435 |
| dk2 | FloatArrayLike | Second-order derivative of coupling coefficient with respect to wavelength (∂²κ/∂λ²). Units: μm⁻². Defaults to 5.3022. | 5.3022 |
| dn | FloatArrayLike | Effective index difference between even and odd supermodes at the reference wavelength. Determines beating length. Defaults to 0.02. | 0.02 |
| dn1 | FloatArrayLike | First-order derivative of effective index difference with respect to wavelength (∂Δn/∂λ). Units: μm⁻¹. Defaults to 0.1169. | 0.1169 |
| dn2 | FloatArrayLike | Second-order derivative of effective index difference with respect to wavelength (∂²Δn/∂λ²). Units: μm⁻². Defaults to 0.4821. | 0.4821 |

Returns:

| Type | Description |
| --- | --- |
| SDict | The coupler S-matrix. |

Examples:

Basic dispersive coupler:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.coupler(
    wl=wl,
    length=15.7,
    coupling0=0.0,
    dn=0.02,
)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Note

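The total coupling angle follows the formula:

κ_total = κ₀ + κ₁·length

Where:

  • κ₀ = coupling0 + dk1·(λ-λ₀) + 0.5·dk2·(λ-λ₀)²
  • κ₁ = π·Δn(λ)/λ, with Δn(λ) = dn + dn1·(λ-λ₀) + 0.5·dn2·(λ-λ₀)²
  • κ₁·length is the distributed coupling accumulated over the interaction length

The through entries are cos(κ_total) and the cross entries are -i·sin(κ_total).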
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def coupler(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    length: sax.FloatArrayLike = 0.0,
    coupling0: sax.FloatArrayLike = 0.2,
    dk1: sax.FloatArrayLike = 1.2435,
    dk2: sax.FloatArrayLike = 5.3022,
    dn: sax.FloatArrayLike = 0.02,
    dn1: sax.FloatArrayLike = 0.1169,
    dn2: sax.FloatArrayLike = 0.4821,
) -> sax.SDict:
    r"""Dispersive directional coupler model.

    ```{svgbob}
            in1                out1
             o2                o3
              *                *
               \              /
                '------------'
    coupling0 /2   coupling   coupling0 /2
                .------------.
               /  <-length -> \
              *                *
             o1                o4
            in0                out0
    ```

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            multi-wavelength simulations. Defaults to 1.55 μm.
        wl0: Reference wavelength in micrometers for dispersion expansion.
            Typically the center design wavelength. Defaults to 1.55 μm.
        length: Coupling length in micrometers. This is the length over which
            the two waveguides are in close proximity. Defaults to 0.0 μm.
        coupling0: Base coupling coefficient at the reference wavelength from
            bend regions or other coupling mechanisms. Obtained from FDTD
            simulations or measurements. Defaults to 0.2.
        dk1: First-order derivative of coupling coefficient with respect to
            wavelength (∂κ/∂λ). Units: μm⁻¹. Defaults to 1.2435.
        dk2: Second-order derivative of coupling coefficient with respect to
            wavelength (∂²κ/∂λ²). Units: μm⁻². Defaults to 5.3022.
        dn: Effective index difference between even and odd supermodes at
            the reference wavelength. Determines beating length. Defaults to 0.02.
        dn1: First-order derivative of effective index difference with respect
            to wavelength (∂Δn/∂λ). Units: μm⁻¹. Defaults to 0.1169.
        dn2: Second-order derivative of effective index difference with respect
            to wavelength (∂²Δn/∂λ²). Units: μm⁻². Defaults to 0.4821.

    Returns:
        The coupler s-matrix

    Examples:
        Basic dispersive coupler:

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wl = sax.wl_c()
    s = sax.models.coupler(
        wl=wl,
        length=15.7,
        coupling0=0.0,
        dn=0.02,
    )
    thru = np.abs(s[("o1", "o4")]) ** 2
    cross = np.abs(s[("o1", "o3")]) ** 2
    plt.figure()
    plt.plot(wl, thru, label="thru")
    plt.plot(wl, cross, label="cross")
    plt.xlabel("Wavelength [μm]")
    plt.ylabel("Power")
    plt.legend()
    ```

    Note:
        The total coupling angle follows the formula:

        κ_total = κ₀ + κ₁*length

        Where:

        - κ₀ = coupling0 + dk1*(λ-λ₀) + 0.5*dk2*(λ-λ₀)²
        - κ₁ = π*Δn(λ)/λ
        - κ₁*length is the distributed coupling accumulated over the
          interaction length

    """
    dwl = wl - wl0
    dn = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    kappa0 = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2
    kappa1 = jnp.pi * dn / wl

    tau = jnp.cos(kappa0 + kappa1 * length)
    kappa = -jnp.sin(kappa0 + kappa1 * length)
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): 1j * kappa,
            (p.in1, p.out0): 1j * kappa,
            (p.in1, p.out1): tau,
        }
    )
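
The body of `coupler` above is short enough to replay with scalars. This plain-Python sketch follows the same steps (wavelength offset, dispersive Δn, κ₀, κ₁, then cosine and sine of the total angle) and can be used to confirm that the through and cross powers always sum to 1. The function name `coupler_amplitudes` is hypothetical.

```python
import math


def coupler_amplitudes(wl, wl0=1.55, length=0.0, coupling0=0.2,
                       dk1=1.2435, dk2=5.3022, dn=0.02, dn1=0.1169, dn2=0.4821):
    """Scalar version of the tau/kappa computation in the listing above."""
    dwl = wl - wl0
    dn_wl = dn + dn1 * dwl + 0.5 * dn2 * dwl**2
    kappa0 = coupling0 + dk1 * dwl + 0.5 * dk2 * dwl**2
    kappa1 = math.pi * dn_wl / wl
    angle = kappa0 + kappa1 * length  # total coupling angle in radians
    tau = math.cos(angle)             # (in0, out0) and (in1, out1)
    kappa = -math.sin(angle)          # multiplied by 1j for the cross entries
    return tau, kappa


tau, kappa = coupler_amplitudes(wl=1.55, length=15.7, coupling0=0.0, dn=0.02)
print(f"thru={tau**2:.3f} cross={kappa**2:.3f}")
```

Because the two amplitudes are the cosine and sine of the same angle, the model is lossless by construction; only the split between the outputs changes with wavelength and length.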

coupler_ideal

coupler_ideal(*, wl: FloatArrayLike = WL_C, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 2x2 directional coupler model.

[Diagram: 2x2 coupler. Left ports: o1 (in0, bottom) and o2 (in1, top). Right ports: o3 (out1, top) and o4 (out0, bottom).]

Parameters:

Name Type Description Default
wl FloatArrayLike

The wavelength of the simulation in micrometers.

WL_C
coupling FloatArrayLike

Power coupling coefficient between 0 and 1. Defaults to 0.5.

0.5

Returns:

Type Description
SDict

The coupler s-matrix

Examples:

Ideal coupler:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.coupler_ideal(
    wl=wl,
    coupling=0.3,
)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def coupler_ideal(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    coupling: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Ideal 2x2 directional coupler model.

    ```{svgbob}
     in1          out1
      o2          o3
       *          *
        \        /
         '------'
         coupling
         .------.
        /        \
       *          *
      o1          o4
     in0          out0
    ```

    Args:
        wl: The wavelength of the simulation in micrometers.
        coupling: Power coupling coefficient between 0 and 1. Defaults to 0.5.

    Returns:
        The coupler s-matrix

    Examples:
        Ideal coupler:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.coupler_ideal(
            wl=wl,
            coupling=0.3,
        )
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(wl)
    kappa = jnp.asarray(coupling**0.5) * one
    tau = jnp.asarray((1 - coupling) ** 0.5) * one
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): 1j * kappa,
            (p.in1, p.out0): 1j * kappa,
            (p.in1, p.out1): tau,
        },
    )
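
As a sanity check on the model above, the 2x2 transfer matrix built from `tau` and `1j * kappa` should be unitary for any power coupling between 0 and 1. The sketch below verifies this in plain Python; `ideal_coupler_matrix` and `is_unitary` are hypothetical helper names, not part of SAX.

```python
import math


def ideal_coupler_matrix(coupling):
    """2x2 transfer matrix [[tau, i*kappa], [i*kappa, tau]] for a power coupling."""
    tau = math.sqrt(1.0 - coupling)
    kappa = math.sqrt(coupling)
    return [[tau, 1j * kappa], [1j * kappa, tau]]


def is_unitary(m, tol=1e-12):
    """Check that the conjugate transpose of m times m is the identity."""
    for i in range(2):
        for j in range(2):
            acc = sum(m[k][i].conjugate() * m[k][j] for k in range(2))
            if abs(acc - (1.0 if i == j else 0.0)) > tol:
                return False
    return True


m = ideal_coupler_matrix(0.3)
thru_power = abs(m[0][0]) ** 2
cross_power = abs(m[0][1]) ** 2
print(thru_power, cross_power, is_unitary(m))
```

The factor `1j` on the cross terms is what makes the off-diagonal products cancel; with real cross terms the matrix would not conserve power.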

cpw_epsilon_eff

cpw_epsilon_eff(
    w: FloatArrayLike, s: FloatArrayLike, h: FloatArrayLike, ep_r: FloatArrayLike
) -> Array

Effective permittivity of a CPW on a finite-height substrate.

\[ \begin{aligned} k_0 &= \frac{w}{w + 2s} \\ k_1 &= \frac{\sinh(\pi w / 4h)}{\sinh\bigl(\pi(w + 2s) / 4h\bigr)} \\ q_1 &= \frac{K(k_1^2)/K(1 - k_1^2)}{K(k_0^2)/K(1 - k_0^2)} \\ \varepsilon_{\mathrm{eff}} &= 1 + \frac{q_1(\varepsilon_r - 1)}{2} \end{aligned} \]

where \(K\) is the complete elliptic integral of the first kind in the parameter convention (\(m = k^2\)).

References

Simons, Eq. 2.37; Ghione & Naldi

Parameters:

Name Type Description Default
w FloatArrayLike

Centre-conductor width (m).

required
s FloatArrayLike

Gap to ground plane (m).

required
h FloatArrayLike

Substrate height (m).

required
ep_r FloatArrayLike

Relative permittivity of the substrate.

required

Returns:

Type Description
Array

Effective permittivity (dimensionless).
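
The formula can be explored without JAX or `jaxellip`. The sketch below is a plain-Python re-derivation under stated assumptions, not the library function: `ellipk_py` evaluates K(m) with the arithmetic-geometric mean (AGM), and the `_py` names and the geometry values are illustrative. For a substrate much thicker than the conductor footprint, k₁ approaches k₀, so the result should approach (1 + ε_r)/2.

```python
import math


def ellipk_py(m, iterations=40):
    """K(m) in the parameter convention (m = k**2), for 0 <= m < 1, via the AGM."""
    a, b = 1.0, math.sqrt(1.0 - m)
    for _ in range(iterations):
        a, b = 0.5 * (a + b), math.sqrt(a * b)
    return math.pi / (2.0 * a)


def ellipk_ratio_py(m):
    """K(m) / K(1 - m) for 0 < m < 1."""
    return ellipk_py(m) / ellipk_py(1.0 - m)


def cpw_epsilon_eff_py(w, s, h, ep_r):
    """Effective permittivity from the k0, k1, q1 expressions above."""
    k0 = w / (w + 2.0 * s)
    k1 = math.sinh(math.pi * w / (4.0 * h)) / math.sinh(
        math.pi * (w + 2.0 * s) / (4.0 * h)
    )
    q1 = ellipk_ratio_py(k1**2) / ellipk_ratio_py(k0**2)
    return 1.0 + q1 * (ep_r - 1.0) / 2.0


# Illustrative geometry: 10 um strip, 6 um gaps, two substrate heights.
thick = cpw_epsilon_eff_py(w=10e-6, s=6e-6, h=500e-6, ep_r=11.45)
thin = cpw_epsilon_eff_py(w=10e-6, s=6e-6, h=5e-6, ep_r=11.45)
print(thick, thin)
```

A thinner substrate pushes more of the field into the air below it, which lowers the filling factor q₁ and therefore the effective permittivity.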

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def cpw_epsilon_eff(
    w: sax.FloatArrayLike,
    s: sax.FloatArrayLike,
    h: sax.FloatArrayLike,
    ep_r: sax.FloatArrayLike,
) -> jax.Array:
    r"""Effective permittivity of a CPW on a finite-height substrate.

    $$
    \begin{aligned}
        k_0 &= \frac{w}{w + 2s} \\
        k_1 &= \frac{\sinh(\pi w / 4h)}{\sinh\bigl(\pi(w + 2s) / 4h\bigr)} \\
        q_1 &= \frac{K(k_1^2)/K(1 - k_1^2)}{K(k_0^2)/K(1 - k_0^2)}  \\
        \varepsilon_{\mathrm{eff}} &= 1 + \frac{q_1(\varepsilon_r - 1)}{2}
    \end{aligned}
    $$

    where $K$ is the complete elliptic integral of the first kind in
    the *parameter* convention ($m = k^2$).

    References:
        Simons, Eq. 2.37;
        Ghione & Naldi

    Args:
        w: Centre-conductor width (m).
        s: Gap to ground plane (m).
        h: Substrate height (m).
        ep_r: Relative permittivity of the substrate.

    Returns:
        Effective permittivity (dimensionless).
    """
    w = jnp.asarray(w, dtype=float)
    s = jnp.asarray(s, dtype=float)
    h = jnp.asarray(h, dtype=float)
    ep_r = jnp.asarray(ep_r, dtype=float)
    k0 = w / (w + 2.0 * s)
    k1 = jnp.sinh(jnp.pi * w / (4.0 * h)) / jnp.sinh(jnp.pi * (w + 2.0 * s) / (4.0 * h))
    q1 = ellipk_ratio(k1**2) / ellipk_ratio(k0**2)
    return 1.0 + q1 * (ep_r - 1.0) / 2.0

cpw_thickness_correction

cpw_thickness_correction(
    w: FloatArrayLike, s: FloatArrayLike, t: FloatArrayLike, ep_eff: FloatArrayLike
) -> tuple[Any, Any]

Apply conductor thickness correction to CPW ε_eff and Z₀.

First-order correction from Gupta, Garg, Bahl & Bhartia

\[ \begin{aligned} \Delta &= \frac{1.25\,t}{\pi} \left(1 + \ln\frac{4\pi w}{t}\right) \\ k_e &= k_0 + (1 - k_0^2)\,\frac{\Delta}{2s} \\ \varepsilon_{\mathrm{eff},t} &= \varepsilon_{\mathrm{eff}} - \frac{0.7\,(\varepsilon_{\mathrm{eff}} - 1)\,t/s} {K(k_0^2)/K(1-k_0^2) + 0.7\,t/s} \\ Z_{0,t} &= \frac{30\pi} {\sqrt{\varepsilon_{\mathrm{eff},t}}\; K(k_e^2)/K(1-k_e^2)} \end{aligned} \]
References

Gupta, Garg, Bahl & Bhartia, §7.3, Eqs. 7.98-7.100

Parameters:

Name Type Description Default
w FloatArrayLike

Centre-conductor width (m).

required
s FloatArrayLike

Gap to ground plane (m).

required
t FloatArrayLike

Conductor thickness (m).

required
ep_eff FloatArrayLike

Uncorrected effective permittivity.

required

Returns:

Type Description
tuple[Any, Any]

(ep_eff_t, z0_t): thickness-corrected effective permittivity and characteristic impedance (Ω).
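
To see the direction and size of the correction, the expressions above can be evaluated in plain Python. This is a sketch under stated assumptions rather than the library code: the `_py` names are hypothetical, the elliptic-integral ratio is computed with the arithmetic-geometric mean (AGM), and the geometry and `ep_eff` values are illustrative.

```python
import math


def ellipk_ratio_py(m, iterations=40):
    """K(m) / K(1 - m) for 0 < m < 1, using K(m) = pi / (2 * agm(1, sqrt(1 - m)))."""
    def agm(a, b):
        for _ in range(iterations):
            a, b = 0.5 * (a + b), math.sqrt(a * b)
        return a
    # The pi/2 factors cancel, so the ratio reduces to a ratio of two AGMs.
    return agm(1.0, math.sqrt(m)) / agm(1.0, math.sqrt(1.0 - m))


def cpw_z0_py(w, s, ep_eff):
    """Zero-thickness characteristic impedance."""
    k0 = w / (w + 2.0 * s)
    return 30.0 * math.pi / (math.sqrt(ep_eff) * ellipk_ratio_py(k0**2))


def cpw_thickness_correction_py(w, s, t, ep_eff):
    """Return (ep_eff_t, z0_t); falls back to the uncorrected values for t <= 0."""
    if t <= 0:
        return ep_eff, cpw_z0_py(w, s, ep_eff)
    k0 = w / (w + 2.0 * s)
    q0 = ellipk_ratio_py(k0**2)
    delta = (1.25 * t / math.pi) * (1.0 + math.log(4.0 * math.pi * w / t))
    ke = min(max(k0 + (1.0 - k0**2) * delta / (2.0 * s), 1e-12), 1.0 - 1e-12)
    ep_eff_t = ep_eff - (0.7 * (ep_eff - 1.0) * t / s) / (q0 + 0.7 * t / s)
    z0_t = 30.0 * math.pi / (math.sqrt(ep_eff_t) * ellipk_ratio_py(ke**2))
    return ep_eff_t, z0_t


w, s, ep_eff = 10e-6, 6e-6, 6.2  # illustrative values
z0 = cpw_z0_py(w, s, ep_eff)
ep_t, z0_t = cpw_thickness_correction_py(w, s, t=0.2e-6, ep_eff=ep_eff)
print(z0, ep_t, z0_t)
```

For this geometry a finite thickness lowers both the effective permittivity and the impedance, and the corrected values converge to the uncorrected ones as `t` goes to zero.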

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def cpw_thickness_correction(
    w: sax.FloatArrayLike,
    s: sax.FloatArrayLike,
    t: sax.FloatArrayLike,
    ep_eff: sax.FloatArrayLike,
) -> tuple[typing.Any, typing.Any]:
    r"""Apply conductor thickness correction to CPW ε_eff and Z₀.

    First-order correction from Gupta, Garg, Bahl & Bhartia


    $$
    \begin{aligned}
        \Delta &= \frac{1.25\,t}{\pi}
    \left(1 + \ln\frac{4\pi w}{t}\right) \\
    k_e    &= k_0 + (1 - k_0^2)\,\frac{\Delta}{2s} \\
    \varepsilon_{\mathrm{eff},t}
    &= \varepsilon_{\mathrm{eff}}
    - \frac{0.7\,(\varepsilon_{\mathrm{eff}} - 1)\,t/s}
    {K(k_0^2)/K(1-k_0^2) + 0.7\,t/s} \\
    Z_{0,t} &= \frac{30\pi}
    {\sqrt{\varepsilon_{\mathrm{eff},t}}\;
    K(k_e^2)/K(1-k_e^2)}
    \end{aligned}
    $$

    References:
        Gupta, Garg, Bahl & Bhartia, §7.3, Eqs. 7.98-7.100


    Args:
        w: Centre-conductor width (m).
        s: Gap to ground plane (m).
        t: Conductor thickness (m).
        ep_eff: Uncorrected effective permittivity.

    Returns:
        ``(ep_eff_t, z0_t)`` — thickness-corrected effective permittivity
        and characteristic impedance (Ω).
    """
    w = jnp.asarray(w, dtype=float)
    s = jnp.asarray(s, dtype=float)
    t = jnp.asarray(t, dtype=float)
    ep_eff = jnp.asarray(ep_eff, dtype=float)

    k0 = w / (w + 2.0 * s)
    q0 = ellipk_ratio(k0**2)

    t_safe = jnp.where(t < 1e-15, 1e-15, t)
    delta = (1.25 * t / jnp.pi) * (1.0 + jnp.log(4.0 * jnp.pi * w / t_safe))

    ke = k0 + (1.0 - k0**2) * delta / (2.0 * s)
    ke = jnp.clip(ke, 1e-12, 1.0 - 1e-12)

    ep_eff_t = ep_eff - (0.7 * (ep_eff - 1.0) * t / s) / (q0 + 0.7 * t / s)
    z0_t = 30.0 * jnp.pi / (jnp.sqrt(ep_eff_t) * ellipk_ratio(ke**2))

    ep_eff_t = jnp.where(t <= 0, ep_eff, ep_eff_t)
    z0_t = jnp.where(t <= 0, cpw_z0(w, s, ep_eff), z0_t)

    return ep_eff_t, z0_t

cpw_z0

cpw_z0(w: FloatArrayLike, s: FloatArrayLike, ep_eff: FloatArrayLike) -> Array

Characteristic impedance of a CPW.

\[ Z_0 = \frac{30\pi}{\sqrt{\varepsilon_{\mathrm{eff}}} K(k_0^2)/K(1 - k_0^2)} \]
References

Simons, Eq. 2.38. Note that our \(w\) and \(s\) correspond to Simons' \(s\) and \(w\).

Parameters:

Name Type Description Default
w FloatArrayLike

Centre-conductor width (m).

required
s FloatArrayLike

Gap to ground plane (m).

required
ep_eff FloatArrayLike

Effective permittivity (see cpw_epsilon_eff).

required

Returns:

Type Description
Array

Characteristic impedance (Ω).
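
The dependence on geometry is easiest to see with a small sweep. The sketch below evaluates the formula in plain Python for a fixed strip width and three gap widths; `agm` and `cpw_z0_py` are hypothetical helper names, and the values are illustrative. It relies on the identity K(m) = π / (2·AGM(1, √(1−m))), so the ratio K(k₀²)/K(1−k₀²) becomes a ratio of two arithmetic-geometric means.

```python
import math


def agm(a, b, iterations=40):
    """Arithmetic-geometric mean of a and b."""
    for _ in range(iterations):
        a, b = 0.5 * (a + b), math.sqrt(a * b)
    return a


def cpw_z0_py(w, s, ep_eff):
    """Z0 = 30*pi / (sqrt(ep_eff) * K(k0^2)/K(1 - k0^2))."""
    k0 = w / (w + 2.0 * s)
    ratio = agm(1.0, k0) / agm(1.0, math.sqrt(1.0 - k0**2))
    return 30.0 * math.pi / (math.sqrt(ep_eff) * ratio)


gaps = [2e-6, 6e-6, 12e-6]  # gap widths in metres, illustrative
z0_values = [cpw_z0_py(10e-6, s, 6.2) for s in gaps]
for s, z0 in zip(gaps, z0_values):
    print(f"s = {s * 1e6:.0f} um -> Z0 = {z0:.1f} ohm")
```

Widening the gap reduces k₀, which lowers the elliptic-integral ratio and raises the impedance.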

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def cpw_z0(
    w: sax.FloatArrayLike,
    s: sax.FloatArrayLike,
    ep_eff: sax.FloatArrayLike,
) -> jax.Array:
    r"""Characteristic impedance of a CPW.

    $$
    Z_0 = \frac{30\pi}{\sqrt{\varepsilon_{\mathrm{eff}}} K(k_0^2)/K(1 - k_0^2)}
    $$


    References:
        Simons, Eq. 2.38.
        Note that our $w$ and $s$ correspond to Simons' $s$ and $w$.

    Args:
        w: Centre-conductor width (m).
        s: Gap to ground plane (m).
        ep_eff: Effective permittivity (see `cpw_epsilon_eff`).

    Returns:
        Characteristic impedance (Ω).
    """
    w = jnp.asarray(w, dtype=float)
    s = jnp.asarray(s, dtype=float)
    ep_eff = jnp.asarray(ep_eff, dtype=float)
    k0 = w / (w + 2.0 * s)
    return 30.0 * jnp.pi / (jnp.sqrt(ep_eff) * ellipk_ratio(k0**2))

crossing_ideal

crossing_ideal(wl: FloatArrayLike = WL_C) -> SDict

Ideal waveguide crossing model.

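[Diagram: waveguide crossing. Horizontal path from o1 (in0, left) to o3 (out0, right); vertical path from o2 (in1, top) to o4 (out1, bottom).]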

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C

Returns:

Type Description
SDict

The crossing s-matrix

Examples:

Ideal crossing:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.crossing_ideal(wl=wl)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s.get(("o1", "o2"), jnp.zeros_like(wl))) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/crossings.py
@jax.jit
@validate_call
def crossing_ideal(wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal waveguide crossing model.

    ```{svgbob}
            in1
             o2
              *
              |
              |
     o1 *-----+-----* o3
    in0       |     out0
              |
              *
             o4
           out1
    ```

    Args:
        wl: Wavelength in micrometers.

    Returns:
        The crossing s-matrix

    Examples:
        Ideal crossing:

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import jax.numpy as jnp
    import sax

    sax.set_port_naming_strategy("optical")

    wl = sax.wl_c()
    s = sax.models.crossing_ideal(wl=wl)
    thru = np.abs(s[("o1", "o3")]) ** 2
    cross = np.abs(s.get(("o1", "o2"), jnp.zeros_like(wl))) ** 2
    plt.figure()
    plt.plot(wl, thru, label="thru")
    plt.plot(wl, cross, label="cross")
    plt.xlabel("Wavelength [μm]")
    plt.ylabel("Power")
    plt.legend()
    ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o3): one,
            (p.o2, p.o4): one,
        }
    )
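
The listing only defines the two through paths, which is why the documented example reads the crosstalk term with `s.get(...)` and a zero default. A plain-dictionary stand-in at a single wavelength makes this explicit; `crossing_sdict` is a hypothetical name, and the mirroring step is a simplified illustration of what reciprocity implies, not the SAX implementation.

```python
def crossing_sdict():
    """Plain-dict stand-in for the ideal crossing at a single wavelength."""
    s = {("o1", "o3"): 1.0, ("o2", "o4"): 1.0}
    # Mirror each entry so that transmission is the same in both directions.
    s.update({(b, a): v for (a, b), v in list(s.items())})
    return s


s = crossing_sdict()
thru = abs(s[("o1", "o3")]) ** 2
crosstalk = abs(s.get(("o1", "o2"), 0.0)) ** 2  # key is absent, so this is zero
print(sorted(s), thru, crosstalk)
```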

electrical_open

electrical_open(*, f: FloatArrayLike = DEFAULT_FREQUENCY, n_ports: int = 1) -> SDict

Electrical open connection Sax model.

Useful for specifying some ports to remain open while not exposing them for connections in circuits.

Parameters:

Name Type Description Default
f FloatArrayLike

Array of frequency points in Hz

DEFAULT_FREQUENCY
n_ports int

Number of ports to leave open

1

Returns:

Type Description
SDict

S-dictionary where \(S = I_\text{n_ports}\)

References

Pozar

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True, static_argnames=("n_ports"))
def electrical_open(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    n_ports: int = 1,
) -> sax.SDict:
    r"""Electrical open connection Sax model.

    Useful for specifying some ports to remain open while not exposing
    them for connections in circuits.

    Args:
        f: Array of frequency points in Hz
        n_ports: Number of ports to leave open

    Returns:
        S-dictionary where $S = I_\text{n_ports}$

    References:
        Pozar
    """
    return gamma_0_load(f=f, gamma_0=1, n_ports=n_ports)

electrical_short

electrical_short(*, f: FloatArrayLike = DEFAULT_FREQUENCY, n_ports: int = 1) -> SDict

Electrical short connection Sax model.

Parameters:

Name Type Description Default
f FloatArrayLike

Array of frequency points in Hz

DEFAULT_FREQUENCY
n_ports int

Number of ports to set as shorted

1

Returns:

Type Description
SDict

S-dictionary where \(S = -I_\text{n_ports}\)

References

Pozar
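
The values S = I for an open and S = −I for a short follow from the load reflection coefficient Γ = (Z_L − Z₀)/(Z_L + Z₀). The sketch below works this out in plain Python; `reflection_coefficient` and `termination_sdict` are hypothetical helpers used only for illustration, and the 50 Ω reference is an assumed value.

```python
def reflection_coefficient(z_load, z0=50.0):
    """Gamma = (Z_L - Z_0) / (Z_L + Z_0); an infinite load is treated as an open."""
    if z_load == float("inf"):
        return 1.0
    return (z_load - z0) / (z_load + z0)


def termination_sdict(gamma, n_ports=1):
    """Diagonal S-matrix as a dict: gamma on the diagonal, zero elsewhere."""
    ports = [f"o{i}" for i in range(1, n_ports + 1)]
    return {(p, q): (gamma if p == q else 0.0) for p in ports for q in ports}


open_s = termination_sdict(reflection_coefficient(float("inf")), n_ports=2)
short_s = termination_sdict(reflection_coefficient(0.0), n_ports=2)
print(open_s[("o1", "o1")], short_s[("o1", "o1")], open_s[("o1", "o2")])
```

A matched load (Z_L = Z₀) gives Γ = 0, which sits between the two extremes.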

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True, static_argnames=("n_ports"))
def electrical_short(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    n_ports: int = 1,
) -> sax.SDict:
    r"""Electrical short connection Sax model.

    Args:
        f: Array of frequency points in Hz
        n_ports: Number of ports to set as shorted

    Returns:
        S-dictionary where $S = -I_\text{n_ports}$

    References:
        Pozar
    """
    return gamma_0_load(f=f, gamma_0=-1, n_ports=n_ports)

ellipk_ratio

ellipk_ratio(m: FloatArrayLike) -> Array

Ratio of complete elliptic integrals of the first kind K(m) / K(1-m).
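
For readers without `jaxellip` installed, the ratio can be reproduced with the standard library. The sketch below uses the arithmetic-geometric mean (AGM) identity K(m) = π / (2·AGM(1, √(1−m))); the `_py` names are hypothetical and this is not how the library computes it.

```python
import math


def ellipk_py(m, iterations=40):
    """K(m) in the parameter convention (m = k**2), valid for 0 <= m < 1."""
    a, b = 1.0, math.sqrt(1.0 - m)
    for _ in range(iterations):
        a, b = 0.5 * (a + b), math.sqrt(a * b)
    return math.pi / (2.0 * a)


def ellipk_ratio_py(m):
    """K(m) / K(1 - m), valid for 0 < m < 1."""
    return ellipk_py(m) / ellipk_py(1.0 - m)


print(ellipk_py(0.0))                                 # K(0) equals pi/2
print(ellipk_ratio_py(0.5))                           # numerator equals denominator
print(ellipk_ratio_py(0.2) * ellipk_ratio_py(0.8))    # reciprocal pair
```

Swapping m for 1 − m inverts the ratio, which gives a quick consistency check on any implementation.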

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def ellipk_ratio(m: sax.FloatArrayLike) -> jax.Array:
    """Ratio of complete elliptic integrals of the first kind K(m) / K(1-m)."""
    if jaxellip is None:
        msg = (
            "jaxellip is required for RF models. Install it with `pip install sax[rf]`"
        )
        raise ImportError(msg)
    m_arr = jnp.asarray(m, dtype=float)
    return jaxellip.ellipk(m_arr) / jaxellip.ellipk(1 - m_arr)

gamma_0_load

gamma_0_load(
    *, f: FloatArrayLike = DEFAULT_FREQUENCY, gamma_0: Complex = 0, n_ports: int = 1
) -> SType

Connection with given reflection coefficient.

Parameters:

Name Type Description Default
f FloatArrayLike

Array of frequency points in Hz

DEFAULT_FREQUENCY
gamma_0 Complex

Reflection coefficient Γ₀ of connection

0
n_ports int

Number of ports in the component. The diagonal entries of the S-matrix are set to Γ₀ and the off-diagonal entries to 0.

1

Returns:

Type Description
SType

S-parameter dictionary where \(S = \Gamma_0I_\text{n_ports}\)

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
gamma_0 = 0.5 * np.exp(1j * np.pi / 4)
s = sax.models.rf.gamma_0_load(f=f, gamma_0=gamma_0, n_ports=2)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
plt.plot(f / 1e9, np.abs(s[("o2", "o1")]), label="|S21|")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Magnitude")
plt.legend()
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True, static_argnames=("n_ports"))
def gamma_0_load(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    gamma_0: sax.Complex = 0,
    n_ports: int = 1,
) -> sax.SType:
    r"""Connection with given reflection coefficient.

    Args:
        f: Array of frequency points in Hz
        gamma_0: Reflection coefficient Γ₀ of connection
        n_ports: Number of ports in the component. The diagonal entries of the
            S-matrix are set to Γ₀ and the off-diagonal entries to 0.

    Returns:
        sax.SType: S-parameters dictionary where $S = \Gamma_0I_\text{n_ports}$

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        gamma_0 = 0.5 * np.exp(1j * np.pi / 4)
        s = sax.models.rf.gamma_0_load(f=f, gamma_0=gamma_0, n_ports=2)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o1")]), label="|S21|")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Magnitude")
        plt.legend()
        ```
    """
    f = jnp.asarray(f)
    f_flat = f.ravel()
    sdict = {
        (f"o{i}", f"o{i}"): jnp.full(f_flat.shape[0], gamma_0)
        for i in range(1, n_ports + 1)
    }
    sdict |= {
        (f"o{i}", f"o{j}"): jnp.zeros(f_flat.shape[0], dtype=complex)
        for i in range(1, n_ports + 1)
        for j in range(i + 1, n_ports + 1)
    }
    return sax.reciprocal({k: v.reshape(*f.shape) for k, v in sdict.items()})
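
A reflection coefficient is often quoted as a return loss or a VSWR instead. The sketch below converts the Γ₀ used in the documented example with the standard textbook relations; it uses only the standard library, does not call SAX, and the variable names are illustrative.

```python
import cmath
import math

gamma_0 = 0.5 * cmath.exp(1j * math.pi / 4)  # same value as the documented example

magnitude = abs(gamma_0)
reflected_power = magnitude**2                  # fraction of incident power reflected
return_loss_db = -20.0 * math.log10(magnitude)  # positive when |gamma_0| < 1
vswr = (1.0 + magnitude) / (1.0 - magnitude)    # voltage standing wave ratio

print(f"reflected = {reflected_power:.3f}, RL = {return_loss_db:.2f} dB, VSWR = {vswr:.2f}")
```

The phase of Γ₀ does not enter any of the three quantities; it only shifts the position of the standing-wave pattern.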

grating_coupler

grating_coupler(
    *,
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    loss: FloatArrayLike = 0.0,
    reflection: FloatArrayLike = 0.0,
    reflection_fiber: FloatArrayLike = 0.0,
    bandwidth: FloatArrayLike = 0.04,
) -> SDict

Grating coupler model for fiber-chip coupling.

[Diagram: grating coupler. Port o1 (in0) is the on-chip waveguide; port o2 (out0) is the fiber above the grating.]

Parameters:

Name Type Description Default
wl FloatArrayLike

Operating wavelength in micrometers. Can be a scalar or array for spectral analysis. Defaults to 1.55 μm.

WL_C
wl0 FloatArrayLike

Center wavelength in micrometers where peak transmission occurs. This is the design wavelength of the grating. Defaults to 1.55 μm.

WL_C
loss FloatArrayLike

Insertion loss in dB at the center wavelength. Includes coupling efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB.

0.0
reflection FloatArrayLike

Reflection coefficient from the waveguide side (chip interface). Represents reflections back into the waveguide from grating discontinuities. Range: 0 to 1. Defaults to 0.0.

0.0
reflection_fiber FloatArrayLike

Reflection coefficient from the fiber side (top interface). Represents reflections back toward the fiber from the grating surface. Range: 0 to 1. Defaults to 0.0.

0.0
bandwidth FloatArrayLike

3dB bandwidth in micrometers. Determines the spectral width of the Gaussian transmission profile. Typical values: 20-50 nm. Defaults to 40e-3 μm (40 nm).

0.04

Returns:

Type Description
SDict

The grating coupler s-matrix

Examples:

Basic grating coupler:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = np.linspace(1.5, 1.6, 101)
s = sax.models.grating_coupler(
    wl=wl,
    loss=3.0,
    bandwidth=0.035,
)
plt.figure()
plt.plot(wl, np.abs(s[("o1", "o2")]) ** 2, label="Transmission")
plt.xlabel("Wavelength (μm)")
plt.ylabel("Power")
plt.legend()
Note

The transmission power profile follows a Gaussian shape: P(λ) = P₀ * exp(-(λ-λ₀)²/(2σ²))

The amplitude transmission is: A(λ) = A₀ * exp(-(λ-λ₀)²/(4σ²))

Where σ = bandwidth / (2√(2ln(2))) converts the FWHM of the power profile to the Gaussian standard deviation.

This model assumes:

  • Gaussian spectral response (typical for uniform gratings)
  • Wavelength-independent loss coefficient
  • Linear polarization (TE or TM)
  • Single-mode fiber coupling
  • No higher-order diffraction effects
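
The relation between `bandwidth` and the Gaussian width can be checked numerically. The sketch below evaluates the amplitude expression used in the source listing, `10 ** (-loss / 20) * exp(-(wl - wl0)**2 / (4 * sigma**2))`, in plain Python and confirms that the power falls to half its peak at λ₀ ± bandwidth/2. `grating_amplitude` is a hypothetical helper, not the library function.

```python
import math


def grating_amplitude(wl, wl0=1.55, loss_db=0.0, bandwidth=0.04):
    """Amplitude transmission with a Gaussian power profile of the given FWHM."""
    sigma = bandwidth / (2.0 * math.sqrt(2.0 * math.log(2.0)))
    peak = 10 ** (-loss_db / 20)
    return peak * math.exp(-((wl - wl0) ** 2) / (4.0 * sigma**2))


wl0, bandwidth = 1.55, 0.035
p_peak = grating_amplitude(wl0, wl0, loss_db=3.0, bandwidth=bandwidth) ** 2
p_edge = grating_amplitude(wl0 + bandwidth / 2, wl0, loss_db=3.0, bandwidth=bandwidth) ** 2
print(p_peak, p_edge / p_peak)
```

The peak power equals 10^(−loss/10), so a 3 dB insertion loss leaves roughly half the power at the center wavelength.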
Source code in src/sax/models/couplers.py
@jax.jit
@validate_call
def grating_coupler(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    loss: sax.FloatArrayLike = 0.0,
    reflection: sax.FloatArrayLike = 0.0,
    reflection_fiber: sax.FloatArrayLike = 0.0,
    bandwidth: sax.FloatArrayLike = 40e-3,
) -> sax.SDict:
    """Grating coupler model for fiber-chip coupling.

    ```{svgbob}
                  out0
                  o2
             /  /  /  /
            /  /  /  /  fiber
           /  /  /  /
           _   _   _
     o1  _| |_| |_| |__
    in0
    ```

    Args:
        wl: Operating wavelength in micrometers. Can be a scalar or array for
            spectral analysis. Defaults to 1.55 μm.
        wl0: Center wavelength in micrometers where peak transmission occurs.
            This is the design wavelength of the grating. Defaults to 1.55 μm.
        loss: Insertion loss in dB at the center wavelength. Includes coupling
            efficiency losses, scattering, and mode mismatch. Defaults to 0.0 dB.
        reflection: Reflection coefficient from the waveguide side (chip interface).
            Represents reflections back into the waveguide from grating discontinuities.
            Range: 0 to 1. Defaults to 0.0.
        reflection_fiber: Reflection coefficient from the fiber side (top interface).
            Represents reflections back toward the fiber from the grating surface.
            Range: 0 to 1. Defaults to 0.0.
        bandwidth: 3dB bandwidth in micrometers. Determines the spectral width
            of the Gaussian transmission profile. Typical values: 20-50 nm.
            Defaults to 40e-3 μm (40 nm).

    Returns:
        The grating coupler s-matrix

    Examples:
        Basic grating coupler:

    ```python
    # mkdocs: render
    import matplotlib.pyplot as plt
    import numpy as np
    import sax

    sax.set_port_naming_strategy("optical")

    wl = np.linspace(1.5, 1.6, 101)
    s = sax.models.grating_coupler(
        wl=wl,
        loss=3.0,
        bandwidth=0.035,
    )
    plt.figure()
    plt.plot(wl, np.abs(s[("o1", "o2")]) ** 2, label="Transmission")
    plt.xlabel("Wavelength (μm)")
    plt.ylabel("Power")
    plt.legend()
    ```

    Note:
        The transmission power profile follows a Gaussian shape:
        P(λ) = P₀ * exp(-(λ-λ₀)²/(2σ²))

        The amplitude transmission is:
        A(λ) = A₀ * exp(-(λ-λ₀)²/(4σ²))

        Where σ = bandwidth / (2*√(2*ln(2))) converts the FWHM of the power
        profile to the Gaussian standard deviation.

        This model assumes:

        - Gaussian spectral response (typical for uniform gratings)
        - Wavelength-independent loss coefficient
        - Linear polarization (TE or TM)
        - Single-mode fiber coupling
        - No higher-order diffraction effects

    """
    one = jnp.ones_like(wl)
    reflection = jnp.asarray(reflection) * one
    reflection_fiber = jnp.asarray(reflection_fiber) * one
    amplitude = jnp.asarray(10 ** (-loss / 20)) * one
    wl0_array = jnp.asarray(wl0) * one
    sigma = jnp.asarray(bandwidth / (2 * jnp.sqrt(2 * jnp.log(2)))) * one
    transmission = amplitude * jnp.exp(-((wl - wl0_array) ** 2) / (4 * sigma**2))
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.in0): reflection,
            (p.in0, p.out0): transmission,
            (p.out0, p.in0): transmission,
            (p.out0, p.out0): reflection_fiber,
        }
    )

ideal_probe

ideal_probe(wl: FloatArrayLike = WL_C) -> SDict

Ideal 4-port measurement probe with 100% transmission and 100% tap coupling.

This is an unphysical component designed for debugging and measurement purposes. It intercepts a connection and provides access to both forward and backward traveling waves without affecting the signal propagation.

          in ─────────────────── out
           │                     │
           │     (ideal tap)     │
           │                     │
         tap_bwd               tap_fwd

The probe has the following behavior:

  • Full transmission from in to out (and vice versa)
  • Forward tap (tap_fwd) copies the signal traveling from in toward out
  • Backward tap (tap_bwd) copies the signal traveling from out toward in
  • No reflections at any port
  • No cross-coupling between tap ports

Note

This S-matrix is NOT unitary (violates energy conservation). This is intentional—it's a measurement tool, not a physical device.

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers. Defaults to 1.55 μm.

WL_C

Returns:

Type Description
SDict

S-parameter dictionary for the ideal probe.

Example

Probes are typically not used directly, but via the probes argument to sax.circuit():

circuit_fn, info = sax.circuit(
    netlist,
    models=models,
    probes={"mid": "wg1,out"},
)
# Circuit now has additional ports: mid_fwd, mid_bwd
Source code in src/sax/models/probes.py
@validate_call
def ideal_probe(wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal 4-port measurement probe with 100% transmission and 100% tap coupling.

    This is an unphysical component designed for debugging and measurement purposes.
    It intercepts a connection and provides access to both forward and backward
    traveling waves without affecting the signal propagation.

    ```
              in ─────────────────── out
               │                     │
               │     (ideal tap)     │
               │                     │
             tap_bwd               tap_fwd
    ```

    The probe has the following behavior:

    - Full transmission from `in` to `out` (and vice versa)
    - Forward tap (`tap_fwd`) copies the signal traveling from `in` toward `out`
    - Backward tap (`tap_bwd`) copies the signal traveling from `out` toward `in`
    - No reflections at any port
    - No cross-coupling between tap ports

    Note:
        This S-matrix is NOT unitary (violates energy conservation). This is
        intentional—it's a measurement tool, not a physical device.

    Args:
        wl: Wavelength in micrometers. Defaults to 1.55 μm.

    Returns:
        S-parameter dictionary for the ideal probe.

    Example:
        Probes are typically not used directly, but via the `probes` argument
        to `sax.circuit()`:

        ```python
        circuit_fn, info = sax.circuit(
            netlist,
            models=models,
            probes={"mid": "wg1,out"},
        )
        # Circuit now has additional ports: mid_fwd, mid_bwd
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    zero = jnp.zeros_like(jnp.asarray(wl))
    return {
        # Through path: full transmission
        ("in", "out"): one,
        ("out", "in"): one,
        # Forward tap: copies signal from in→out direction
        ("in", "tap_fwd"): one,
        ("tap_fwd", "in"): one,
        # Backward tap: copies signal from out→in direction
        ("out", "tap_bwd"): one,
        ("tap_bwd", "out"): one,
        # No cross-coupling between taps
        ("tap_fwd", "tap_bwd"): zero,
        ("tap_bwd", "tap_fwd"): zero,
        # No coupling from taps to opposite main port
        ("tap_fwd", "out"): zero,
        ("tap_bwd", "in"): zero,
        # No reflections
        ("in", "in"): zero,
        ("out", "out"): zero,
        ("tap_fwd", "tap_fwd"): zero,
        ("tap_bwd", "tap_bwd"): zero,
    }
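
The non-unitarity mentioned in the note can be made concrete by summing the output power for a unit input. The sketch below copies the non-zero entries of the dictionary above at a single wavelength into a plain dict; `output_power` is a hypothetical helper, and treating each key as a (source, destination) pair is an assumption that does not change the result because the entries are symmetric.

```python
PORTS = ("in", "out", "tap_fwd", "tap_bwd")

# Non-zero entries of the probe S-dictionary at a single wavelength,
# copied from the source listing; every other entry is zero.
probe = {
    ("in", "out"): 1.0, ("out", "in"): 1.0,
    ("in", "tap_fwd"): 1.0, ("tap_fwd", "in"): 1.0,
    ("out", "tap_bwd"): 1.0, ("tap_bwd", "out"): 1.0,
}


def output_power(sdict, source):
    """Total power leaving all ports for unit power injected at `source`."""
    return sum(abs(sdict.get((source, port), 0.0)) ** 2 for port in PORTS)


power_from_in = output_power(probe, "in")
power_from_tap = output_power(probe, "tap_fwd")
print(power_from_in, power_from_tap)
```

A unit input at `in` appears in full at both `out` and `tap_fwd`, so the total output power is twice the input, which a passive physical device could not do.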

impedance

impedance(
    *, f: FloatArrayLike = DEFAULT_FREQUENCY, z: ComplexLike = 50, z0: ComplexLike = 50
) -> SDict

Generalized two-port impedance element.

Parameters:

Name Type Description Default
f FloatArrayLike

Frequency in Hz

DEFAULT_FREQUENCY
z ComplexLike

Impedance in Ω

50
z0 ComplexLike

Reference impedance in Ω.

50

Returns:

Type Description
SDict

S-dictionary representing the impedance element

References

Pozar

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
s = sax.models.rf.impedance(f=f, z=75, z0=50)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Magnitude")
plt.legend()
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def impedance(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    z: sax.ComplexLike = 50,
    z0: sax.ComplexLike = 50,
) -> sax.SDict:
    r"""Generalized two-port impedance element.

    Args:
        f: Frequency in Hz
        z: Impedance in Ω
        z0: Reference impedance in Ω.

    Returns:
        S-dictionary representing the impedance element

    References:
        Pozar

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        s = sax.models.rf.impedance(f=f, z=75, z0=50)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Magnitude")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(f))
    sdict = {
        ("o1", "o1"): z / (z + 2 * z0) * one,
        ("o1", "o2"): 2 * z0 / (2 * z0 + z) * one,
        ("o2", "o2"): z / (z + 2 * z0) * one,
    }
    return sax.reciprocal(sdict)
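
The expressions in the listing match the textbook S-parameters of an impedance connected in series between two ports of reference impedance z0. The sketch below evaluates them in plain Python for the documented example values (z = 75 Ω, z0 = 50 Ω); `series_impedance_s` is a hypothetical helper, not the library function.

```python
def series_impedance_s(z, z0=50.0):
    """Return (S11, S21) using the same expressions as the listing."""
    s11 = z / (z + 2 * z0)
    s21 = 2 * z0 / (z + 2 * z0)
    return s11, s21


s11, s21 = series_impedance_s(75.0, 50.0)
dissipated = 1.0 - abs(s11) ** 2 - abs(s21) ** 2  # power absorbed by the element
print(s11, s21, dissipated)
```

For this element S11 + S21 = 1 for any z, and a real (resistive) z absorbs part of the incident power, so the two power terms sum to less than one.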

inductor

inductor(
    *,
    f: FloatArrayLike = DEFAULT_FREQUENCY,
    inductance: FloatLike = 1e-12,
    z0: ComplexLike = 50,
) -> SDict

Ideal two-port inductor model.

Parameters:

Name Type Description Default
f FloatArrayLike

Frequency in Hz

DEFAULT_FREQUENCY
inductance FloatLike

Inductance in Henries

1e-12
z0 ComplexLike

Reference impedance in Ω.

50

Returns:

Type Description
SDict

S-dictionary representing the inductor element

References

Pozar

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
s = sax.models.rf.inductor(f=f, inductance=1e-9, z0=50)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Magnitude")
plt.legend()
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def inductor(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    inductance: sax.FloatLike = 1e-12,
    z0: sax.ComplexLike = 50,
) -> sax.SDict:
    r"""Ideal two-port inductor model.

    Args:
        f: Frequency in Hz
        inductance: Inductance in Henries
        z0: Reference impedance in Ω.

    Returns:
        S-dictionary representing the inductor element

    References:
        Pozar

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        s = sax.models.rf.inductor(f=f, inductance=1e-9, z0=50)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Magnitude")
        plt.legend()
        ```
    """
    angular_frequency = 2 * jnp.pi * jnp.asarray(f)
    inductor_impedance = 1j * angular_frequency * inductance
    return impedance(f=f, z=inductor_impedance, z0=z0)
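
Because an ideal inductor is purely reactive, it should not dissipate power. The sketch below substitutes z = jωL into the same impedance expressions that `impedance` uses and checks that |S11|² + |S21|² = 1; `inductor_s` is a hypothetical plain-Python helper, and the 1 nH and 5 GHz values are illustrative.

```python
import math


def inductor_s(f, inductance, z0=50.0):
    """Return (S11, S21) of the impedance z = j*2*pi*f*L."""
    z = 1j * 2.0 * math.pi * f * inductance
    return z / (z + 2 * z0), 2 * z0 / (z + 2 * z0)


s11, s21 = inductor_s(f=5e9, inductance=1e-9)
total = abs(s11) ** 2 + abs(s21) ** 2
s11_high, _ = inductor_s(f=10e9, inductance=1e-9)
print(abs(s11), abs(s21), total)
```

The reflection grows with frequency because the inductor's reactance ωL grows linearly with it.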

isolator

isolator(
    *,
    wl: FloatArrayLike = WL_C,
    insertion_loss_dB: FloatArrayLike = 0.0,
    isolation_dB: FloatArrayLike = 40.0,
) -> SDict

Optical isolator model (non-reciprocal).

Transmits light in the forward direction (in0 -> out0) with low loss while blocking the reverse direction (out0 -> in0).

[Diagram: isolator. Light passes from o1 (in0) to o2 (out0).]

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C
insertion_loss_dB FloatArrayLike

Forward insertion loss in dB. Defaults to 0.0 dB.

0.0
isolation_dB FloatArrayLike

Reverse isolation in dB. Higher values mean better blocking of backward-propagating light. Defaults to 40.0 dB.

40.0

Returns:

Type Description
SDict

S-matrix dictionary for the isolator.

Examples:

Isolator with 1 dB insertion loss and 30 dB isolation:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.isolator(
    wl=wl,
    insertion_loss_dB=1.0,
    isolation_dB=30.0,
)
fwd = np.abs(s[("o1", "o2")]) ** 2
bwd = np.abs(s[("o2", "o1")]) ** 2
plt.figure()
plt.plot(wl, 10 * np.log10(fwd), label="forward")
plt.plot(wl, 10 * np.log10(bwd + 1e-10), label="backward")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Transmission [dB]")
plt.legend()
Source code in src/sax/models/isolators.py
@jax.jit
@validate_call
def isolator(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    insertion_loss_dB: sax.FloatArrayLike = 0.0,
    isolation_dB: sax.FloatArrayLike = 40.0,
) -> sax.SDict:
    """Optical isolator model (non-reciprocal).

    Transmits light in the forward direction (in0 -> out0) with low loss
    while blocking the reverse direction (out0 -> in0).

    ```{svgbob}
    in0             out0
     o1 ====>====== o2
    ```

    Args:
        wl: Wavelength in micrometers.
        insertion_loss_dB: Forward insertion loss in dB. Defaults to 0.0 dB.
        isolation_dB: Reverse isolation in dB. Higher values mean better
            blocking of backward-propagating light. Defaults to 40.0 dB.

    Returns:
        S-matrix dictionary for the isolator.

    Examples:
        Isolator with 1 dB insertion loss and 30 dB isolation:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.isolator(
            wl=wl,
            insertion_loss_dB=1.0,
            isolation_dB=30.0,
        )
        fwd = np.abs(s[("o1", "o2")]) ** 2
        bwd = np.abs(s[("o2", "o1")]) ** 2
        plt.figure()
        plt.plot(wl, 10 * np.log10(fwd), label="forward")
        plt.plot(wl, 10 * np.log10(bwd + 1e-10), label="backward")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Transmission [dB]")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    forward = jnp.asarray(10 ** (-insertion_loss_dB / 20), dtype=complex) * one
    backward = jnp.asarray(10 ** (-isolation_dB / 20), dtype=complex) * one
    return {
        ("o1", "o2"): forward,
        ("o2", "o1"): backward,
    }
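
The source above converts dB values to field amplitudes with a factor of 20 rather than 10, because S-parameters are amplitudes and power goes as the amplitude squared. A minimal plain-Python sketch of that conversion follows; the helper name `db_to_amplitude` is hypothetical and not part of SAX.

```python
import math


def db_to_amplitude(loss_db: float) -> float:
    """Convert a power loss in dB to a field-amplitude factor."""
    return 10 ** (-loss_db / 20)


forward = db_to_amplitude(1.0)  # insertion_loss_dB = 1.0
backward = db_to_amplitude(30.0)  # isolation_dB = 30.0

# Power transmission in dB is 10 * log10(|amplitude|^2).
forward_db = 10 * math.log10(forward**2)
backward_db = 10 * math.log10(backward**2)
print(forward_db, backward_db)
```

Squaring the amplitudes recovers the original dB figures, which is what the plotted `10 * np.log10(fwd)` curve in the example shows.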

lc_shunt_component

lc_shunt_component(
    f: FloatArrayLike = 5000000000.0,
    inductance: FloatLike = 1e-09,
    capacitance: FloatLike = 1e-12,
    z0: FloatLike = 50,
) -> SDict

SAX component for a 1-port shunted LC resonator.

Circuit layout: port o1 feeds a capacitor C and an inductor L in parallel, both returning to ground.
Source code in src/sax/models/rf.py
@jax.jit
def lc_shunt_component(
    f: sax.FloatArrayLike = 5e9,
    inductance: sax.FloatLike = 1e-9,
    capacitance: sax.FloatLike = 1e-12,
    z0: sax.FloatLike = 50,
) -> sax.SDict:
    """SAX component for a 1-port shunted LC resonator.

    ```{svgbob}
             o1
             *
             |
        +----+----+
        |         |
       --- C      C L
       ---        C
        |         |
        +----+----+
             |
           -----
            ---
             -
    ```
    """
    f = jnp.asarray(f)
    instances = {
        "L": inductor(f=f, inductance=inductance, z0=z0),
        "C": capacitor(f=f, capacitance=capacitance, z0=z0),
        "gnd": electrical_short(f=f, n_ports=1),
        "tee_1": tee(f=f),
        "tee_2": tee(f=f),
    }
    connections = {
        "L,o1": "tee_1,o1",
        "C,o1": "tee_1,o2",
        "L,o2": "tee_2,o1",
        "C,o2": "tee_2,o2",
        "gnd,o1": "tee_2,o3",
    }
    ports = {
        "o1": "tee_1,o3",
    }

    return sax.backends.evaluate_circuit_fg((connections, ports), instances)
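
For intuition about this circuit, the sketch below evaluates the reflection coefficient of an ideal parallel LC to ground directly from its admittance, without SAX. It assumes the netlist above is equivalent to that lumped circuit and that the port is referenced to `z0`; the function name `lc_shunt_reflection` is hypothetical.

```python
import math


def lc_shunt_reflection(
    f: float, inductance: float, capacitance: float, z0: float = 50.0
) -> complex:
    """Reflection coefficient of a parallel LC to ground seen from a z0 port."""
    omega = 2 * math.pi * f
    # Admittances of parallel branches add.
    y = 1j * omega * capacitance + 1 / (1j * omega * inductance)
    return (1 - z0 * y) / (1 + z0 * y)


inductance, capacitance = 1e-9, 1e-12  # defaults of lc_shunt_component
f0 = 1 / (2 * math.pi * math.sqrt(inductance * capacitance))  # resonance in Hz
gamma_res = lc_shunt_reflection(f0, inductance, capacitance)
gamma_off = lc_shunt_reflection(0.5 * f0, inductance, capacitance)
print(f0, gamma_res, abs(gamma_off))
```

With the default values the resonance falls just above 5 GHz. At resonance the parallel admittance vanishes, so the port sees an open circuit; away from resonance the magnitude stays at one because the circuit is lossless, and only the phase changes.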

microstrip

microstrip(
    f: FloatArrayLike = DEFAULT_FREQUENCY,
    length: FloatLike = 1000.0,
    width: FloatLike = 10.0,
    substrate_thickness: FloatLike = 500.0,
    thickness: FloatLike = 0.2,
    ep_r: FloatLike = 11.45,
    tand: FloatLike = 0.0,
) -> SDict

S-parameter model for a straight microstrip transmission line.

Computes S-parameters analytically using the Hammerstad-Jensen closed-form expressions for effective permittivity and characteristic impedance, as described in Pozar. Conductor thickness corrections follow Gupta et al.

References

Hammerstad & Jensen; Pozar, ch. 3, §3.8; Gupta, Garg, Bahl & Bhartia, §2.2.4

Parameters:

Name Type Description Default
f FloatArrayLike

Array of frequency points in Hz.

DEFAULT_FREQUENCY
length FloatLike

Physical length in µm.

1000.0
width FloatLike

Strip width in µm.

10.0
substrate_thickness FloatLike

Substrate height in µm.

500.0
thickness FloatLike

Conductor thickness in µm (default 0.2 µm = 200 nm).

0.2
ep_r FloatLike

Relative permittivity of the substrate (default 11.45 for Si).

11.45
tand FloatLike

Dielectric loss tangent (default 0 — lossless).

0.0

Returns:

Type Description
SDict

sax.SDict: S-parameters dictionary.

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def microstrip(
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    length: sax.FloatLike = 1000.0,
    width: sax.FloatLike = 10.0,
    substrate_thickness: sax.FloatLike = 500.0,
    thickness: sax.FloatLike = 0.2,
    ep_r: sax.FloatLike = 11.45,
    tand: sax.FloatLike = 0.0,
) -> sax.SDict:
    r"""S-parameter model for a straight microstrip transmission line.

    Computes S-parameters analytically using the Hammerstad-Jensen closed-form
    expressions for effective permittivity and characteristic impedance, as
    described in Pozar.
    Conductor thickness corrections follow
    Gupta et al.

    References:
        Hammerstad & Jensen;
        Pozar, ch. 3, §3.8;
        Gupta, Garg, Bahl & Bhartia, §2.2.4

    Args:
        f: Array of frequency points in Hz.
        length: Physical length in µm.
        width: Strip width in µm.
        substrate_thickness: Substrate height in µm.
        thickness: Conductor thickness in µm (default 0.2 µm = 200 nm).
        ep_r: Relative permittivity of the substrate (default 11.45 for Si).
        tand: Dielectric loss tangent (default 0 — lossless).

    Returns:
        sax.SDict: S-parameters dictionary.
    """
    f = jnp.asarray(f)
    f_flat = f.ravel()

    w_m = width * 1e-6
    h_m = substrate_thickness * 1e-6
    t_m = thickness * 1e-6
    length_m = jnp.asarray(length) * 1e-6

    ep_eff = microstrip_epsilon_eff(w_m, h_m, ep_r)
    _w_eff, ep_eff_t, z0_val = microstrip_thickness_correction(
        w_m, h_m, t_m, ep_r, ep_eff
    )

    gamma = propagation_constant(f_flat, ep_eff_t, tand=tand, ep_r=ep_r)
    s11, s21 = transmission_line_s_params(gamma, z0_val, length_m)

    sdict: sax.SDict = {
        ("o1", "o1"): s11.reshape(f.shape),
        ("o1", "o2"): s21.reshape(f.shape),
        ("o2", "o2"): s11.reshape(f.shape),
    }
    return sax.reciprocal(sdict)
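
The helpers `propagation_constant` and `transmission_line_s_params` are not shown in this section. As a hedged illustration of the final step, the sketch below uses the textbook S-parameters of a uniform line between two reference ports. The 50 Ω reference impedance, the lossless propagation constant, and the numeric values of `ep_eff` and `z0` are assumptions made for the example, and `line_s_params` is a hypothetical name.

```python
import cmath
import math

C0 = 299_792_458.0  # speed of light in vacuum, m/s


def line_s_params(
    gamma: complex, z0: float, length: float, z_ref: float = 50.0
) -> tuple[complex, complex]:
    """Textbook S11 and S21 of a uniform line (z0, gamma, length) between z_ref ports."""
    theta = gamma * length
    denom = 2 * z0 * z_ref * cmath.cosh(theta) + (z0**2 + z_ref**2) * cmath.sinh(theta)
    s11 = (z0**2 - z_ref**2) * cmath.sinh(theta) / denom
    s21 = 2 * z0 * z_ref / denom
    return s11, s21


f = 5e9  # frequency in Hz
ep_eff = 6.6  # illustrative effective permittivity (assumed value)
z0 = 120.0  # illustrative characteristic impedance in ohms (assumed value)
length_m = 1000.0 * 1e-6  # 1000 um converted to metres, as in the model

beta = 2 * math.pi * f * math.sqrt(ep_eff) / C0  # lossless phase constant, rad/m
s11, s21 = line_s_params(1j * beta, z0, length_m)
s11_matched, s21_matched = line_s_params(1j * beta, 50.0, length_m)
print(abs(s11), abs(s21), abs(s11_matched))
```

A line whose characteristic impedance equals the reference impedance shows no reflection and only a phase delay, while a mismatched lossless line splits the power between reflection and transmission.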

microstrip_epsilon_eff

microstrip_epsilon_eff(
    w: FloatArrayLike, h: FloatArrayLike, ep_r: FloatArrayLike
) -> Array

Effective permittivity of a microstrip line.

Uses the Hammerstad-Jensen formula as given in Pozar.

\[ \varepsilon_{\mathrm{eff}} = \frac{\varepsilon_r + 1}{2} + \frac{\varepsilon_r - 1}{2} \left(\frac{1}{\sqrt{1 + 12h/w}} + 0.04(1 - w/h)^2 \Theta(1 - w/h)\right) \]

where the last term contributes only for narrow strips (\(w/h < 1\)).

References

Hammerstad & Jensen; Pozar, Eqs. 3.195-3.196.

Parameters:

Name Type Description Default
w FloatArrayLike

Strip width (m).

required
h FloatArrayLike

Substrate height (m).

required
ep_r FloatArrayLike

Relative permittivity of the substrate.

required

Returns:

Type Description
Array

Effective permittivity (dimensionless).

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def microstrip_epsilon_eff(
    w: sax.FloatArrayLike,
    h: sax.FloatArrayLike,
    ep_r: sax.FloatArrayLike,
) -> jax.Array:
    r"""Effective permittivity of a microstrip line.

    Uses the Hammerstad-Jensen formula as given in Pozar.

    $$
    \varepsilon_{\mathrm{eff}} = \frac{\varepsilon_r + 1}{2}
        + \frac{\varepsilon_r - 1}{2} \left(\frac{1}{\sqrt{1 + 12h/w}}
        + 0.04(1 - w/h)^2 \Theta(1 - w/h)\right)
    $$

    where the last term contributes only for narrow strips ($w/h < 1$).

    References:
        Hammerstad & Jensen;
        Pozar, Eqs. 3.195-3.196.

    Args:
        w: Strip width (m).
        h: Substrate height (m).
        ep_r: Relative permittivity of the substrate.

    Returns:
        Effective permittivity (dimensionless).
    """
    w = jnp.asarray(w, dtype=float)
    h = jnp.asarray(h, dtype=float)
    ep_r = jnp.asarray(ep_r, dtype=float)

    u = w / h
    f_u = 1.0 / jnp.sqrt(1.0 + 12.0 / u)

    narrow_correction = 0.04 * (1.0 - u) ** 2
    f_u = jnp.where(u < 1.0, f_u + narrow_correction, f_u)

    return (ep_r + 1.0) / 2.0 + (ep_r - 1.0) / 2.0 * f_u
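
To get a feel for the numbers, here is a scalar plain-Python transcription of the formula above, with the step function written as an `if`. It is a sketch for illustration only; `epsilon_eff` is a hypothetical name and the wide-strip geometry is an arbitrary example.

```python
import math


def epsilon_eff(w: float, h: float, ep_r: float) -> float:
    """Scalar version of the effective-permittivity formula above."""
    u = w / h
    f_u = 1.0 / math.sqrt(1.0 + 12.0 / u)
    if u < 1.0:  # narrow-strip correction term
        f_u += 0.04 * (1.0 - u) ** 2
    return (ep_r + 1.0) / 2.0 + (ep_r - 1.0) / 2.0 * f_u


narrow = epsilon_eff(10e-6, 500e-6, 11.45)  # w/h = 0.02
wide = epsilon_eff(5000e-6, 500e-6, 11.45)  # w/h = 10
print(narrow, wide)
```

The result always lies between 1 (air) and `ep_r`: a narrow strip has more of its field in air and gives a lower value, while a wide strip approaches the substrate permittivity.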

microstrip_thickness_correction

microstrip_thickness_correction(
    w: FloatArrayLike,
    h: FloatArrayLike,
    t: FloatArrayLike,
    ep_r: FloatArrayLike,
    ep_eff: FloatArrayLike,
) -> tuple[Array, Array, Array]

Conductor thickness correction for a microstrip line.

Uses the widely-adopted Schneider correction as presented in Pozar and Gupta et al.

\[ \begin{aligned} w_e &= w + \frac{t}{\pi} \ln\frac{4e}{\sqrt{(t/h)^2 + \bigl(t/(\pi w + 1.1\pi t)\bigr)^2}} \\ \varepsilon_{\mathrm{eff},t} &= \varepsilon_{\mathrm{eff}} - \frac{(\varepsilon_r - 1)\,t/h} {4.6\,\sqrt{w/h}} \end{aligned} \]

Then the corrected \(Z_0\) is computed with the effective width \(w_e\) and corrected \(\varepsilon_{\mathrm{eff},t}\).

References

Pozar, §3.8; Gupta, Garg, Bahl & Bhartia

Parameters:

Name Type Description Default
w FloatArrayLike

Strip width (m).

required
h FloatArrayLike

Substrate height (m).

required
t FloatArrayLike

Conductor thickness (m).

required
ep_r FloatArrayLike

Relative permittivity of the substrate.

required
ep_eff FloatArrayLike

Uncorrected effective permittivity.

required

Returns:

Type Description
tuple[Array, Array, Array]

(w_eff, ep_eff_t, z0_t): effective width (m), thickness-corrected effective permittivity, and characteristic impedance (Ω).

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def microstrip_thickness_correction(
    w: sax.FloatArrayLike,
    h: sax.FloatArrayLike,
    t: sax.FloatArrayLike,
    ep_r: sax.FloatArrayLike,
    ep_eff: sax.FloatArrayLike,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    r"""Conductor thickness correction for a microstrip line.

    Uses the widely-adopted Schneider correction as presented in
    Pozar and Gupta et al.



    $$
    \begin{aligned}
        w_e &= w + \frac{t}{\pi}
    \ln\frac{4e}{\sqrt{(t/h)^2 + (t/(\pi w + 1.1\pi t))^2}} \\
    \varepsilon_{\mathrm{eff},t}
    &= \varepsilon_{\mathrm{eff}}
    - \frac{(\varepsilon_r - 1)\,t/h}
    {4.6\,\sqrt{w/h}}
    \end{aligned}
    $$


    Then the corrected $Z_0$ is computed with the effective width
    $w_e$ and corrected $\varepsilon_{\mathrm{eff},t}$.

    References:
        Pozar, §3.8;
        Gupta, Garg, Bahl & Bhartia

    Args:
        w: Strip width (m).
        h: Substrate height (m).
        t: Conductor thickness (m).
        ep_r: Relative permittivity of the substrate.
        ep_eff: Uncorrected effective permittivity.

    Returns:
        ``(w_eff, ep_eff_t, z0_t)`` — effective width (m),
        thickness-corrected effective permittivity,
        and characteristic impedance (Ω).
    """
    w = jnp.asarray(w, dtype=float)
    h = jnp.asarray(h, dtype=float)
    t = jnp.asarray(t, dtype=float)
    ep_r = jnp.asarray(ep_r, dtype=float)
    ep_eff = jnp.asarray(ep_eff, dtype=float)

    term = jnp.sqrt((t / h) ** 2 + (t / (w * jnp.pi + 1.1 * t * jnp.pi)) ** 2)
    term_safe = jnp.where(term < 1e-15, 1.0, term)
    w_eff = w + (t / jnp.pi) * jnp.log(4.0 * jnp.e / term_safe)

    ep_eff_t = ep_eff - (ep_r - 1.0) * t / h / (4.6 * jnp.sqrt(w / h))
    z0_t = microstrip_z0(w_eff, h, ep_eff_t)

    w_eff = jnp.where(t <= 0, w, w_eff)
    ep_eff_t = jnp.where(t <= 0, ep_eff, ep_eff_t)
    z0_t = jnp.where(t <= 0, microstrip_z0(w, h, ep_eff), z0_t)

    return w_eff, ep_eff_t, z0_t
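
The two correction formulas can be checked by hand with a scalar sketch such as the one below. It covers only the `t > 0` branch and omits the impedance step; the function name `thickness_correction` is hypothetical and the input `ep_eff=6.64` is an assumed illustrative value.

```python
import math


def thickness_correction(
    w: float, h: float, t: float, ep_r: float, ep_eff: float
) -> tuple[float, float]:
    """Scalar effective width and corrected permittivity for t > 0."""
    term = math.sqrt((t / h) ** 2 + (t / (math.pi * w + 1.1 * math.pi * t)) ** 2)
    w_eff = w + (t / math.pi) * math.log(4.0 * math.e / term)
    ep_eff_t = ep_eff - (ep_r - 1.0) * (t / h) / (4.6 * math.sqrt(w / h))
    return w_eff, ep_eff_t


# Default geometry of the microstrip model, converted from um to metres.
w, h, t = 10e-6, 500e-6, 0.2e-6
w_eff, ep_eff_t = thickness_correction(w, h, t, ep_r=11.45, ep_eff=6.64)
print(w_eff / w, ep_eff_t)
```

A finite conductor thickness makes the strip behave as if it were slightly wider and slightly lowers the effective permittivity; for a thin film on a thick substrate both corrections are small.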

microstrip_z0

microstrip_z0(w: FloatArrayLike, h: FloatArrayLike, ep_eff: FloatArrayLike) -> Array

Characteristic impedance of a microstrip line.

Uses the Hammerstad-Jensen approximation as given in Pozar.

\[ \begin{aligned} Z_0 = \begin{cases} \displaystyle\frac{60}{\sqrt{\varepsilon_{\mathrm{eff}}}} \ln\!\left(\frac{8h}{w} + \frac{w}{4h}\right) & w/h \le 1 \\[6pt] \displaystyle\frac{120\pi} {\sqrt{\varepsilon_{\mathrm{eff}}}\, \bigl[w/h + 1.393 + 0.667\ln(w/h + 1.444)\bigr]} & w/h \ge 1 \end{cases} \end{aligned} \]
References

Hammerstad & Jensen; Pozar, Eqs. 3.197-3.198.

Parameters:

Name Type Description Default
w FloatArrayLike

Strip width (m).

required
h FloatArrayLike

Substrate height (m).

required
ep_eff FloatArrayLike

Effective permittivity (see microstrip_epsilon_eff).

required

Returns:

Type Description
Array

Characteristic impedance (Ω).

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def microstrip_z0(
    w: sax.FloatArrayLike,
    h: sax.FloatArrayLike,
    ep_eff: sax.FloatArrayLike,
) -> jax.Array:
    r"""Characteristic impedance of a microstrip line.

    Uses the Hammerstad-Jensen approximation as given in
    Pozar.


    $$
    \begin{aligned}
        Z_0 = \begin{cases}
    \displaystyle\frac{60}{\sqrt{\varepsilon_{\mathrm{eff}}}}
    \ln\!\left(\frac{8h}{w} + \frac{w}{4h}\right)
    & w/h \le 1 \\[6pt]
    \displaystyle\frac{120\pi}
    {\sqrt{\varepsilon_{\mathrm{eff}}}\,
    \bigl[w/h + 1.393 + 0.667\ln(w/h + 1.444)\bigr]}
    & w/h \ge 1
    \end{cases}
    \end{aligned}
    $$


    References:
        Hammerstad & Jensen;
        Pozar, Eqs. 3.197-3.198.


    Args:
        w: Strip width (m).
        h: Substrate height (m).
        ep_eff: Effective permittivity (see :func:`microstrip_epsilon_eff`).

    Returns:
        Characteristic impedance (Ω).
    """
    w = jnp.asarray(w, dtype=float)
    h = jnp.asarray(h, dtype=float)
    ep_eff = jnp.asarray(ep_eff, dtype=float)

    u = w / h
    z_narrow = (60.0 / jnp.sqrt(ep_eff)) * jnp.log(8.0 / u + u / 4.0)
    z_wide = (
        120.0 * jnp.pi / (jnp.sqrt(ep_eff) * (u + 1.393 + 0.667 * jnp.log(u + 1.444)))
    )

    return jnp.where(u <= 1.0, z_narrow, z_wide)
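
A scalar plain-Python transcription of the two branches makes it easy to check values by hand and to see that the branches nearly agree at `w/h = 1`. This is a sketch only; `microstrip_z0_scalar` is a hypothetical name and `ep_eff=6.64` is an assumed illustrative input.

```python
import math


def microstrip_z0_scalar(w: float, h: float, ep_eff: float) -> float:
    """Scalar version of the piecewise impedance formula above."""
    u = w / h
    if u <= 1.0:  # narrow-strip branch
        return 60.0 / math.sqrt(ep_eff) * math.log(8.0 / u + u / 4.0)
    # wide-strip branch
    return 120.0 * math.pi / (
        math.sqrt(ep_eff) * (u + 1.393 + 0.667 * math.log(u + 1.444))
    )


z_narrow = microstrip_z0_scalar(10e-6, 500e-6, 6.64)  # w/h = 0.02
z_at_one = microstrip_z0_scalar(1.0, 1.0, 1.0)  # narrow branch at w/h = 1
z_above_one = microstrip_z0_scalar(1.0 + 1e-9, 1.0, 1.0)  # wide branch just above 1
print(z_narrow, z_at_one, z_above_one)
```

The two closed-form branches are separate fits, so a small step of well under one percent remains at the crossover.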

mirror

mirror(
    *,
    wl: FloatArrayLike = WL_C,
    reflection: FloatArrayLike = 1.0,
    loss_dB: FloatArrayLike = 0.0,
) -> SDict

Ideal mirror model (reflector with default 100% reflection).

Port layout: a single port o1 (in0) terminated by the mirror.

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C
reflection FloatArrayLike

Power reflection coefficient between 0 and 1. Defaults to 1.0 (perfect mirror).

1.0
loss_dB FloatArrayLike

Insertion loss in dB. Defaults to 0.0 dB.

0.0

Returns:

Type Description
SDict

S-matrix dictionary for the mirror.

Examples:

Perfect mirror:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mirror(wl=wl)
refl = np.abs(s[("o1", "o1")]) ** 2
plt.figure()
plt.plot(wl, refl, label="reflection")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/reflectors.py
@jax.jit
@validate_call
def mirror(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    reflection: sax.FloatArrayLike = 1.0,
    loss_dB: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    r"""Ideal mirror model (reflector with default 100% reflection).

    ```{svgbob}
    in0          ||
     o1 ========= ||
                 ||
              mirror
    ```

    Args:
        wl: Wavelength in micrometers.
        reflection: Power reflection coefficient between 0 and 1.
            Defaults to 1.0 (perfect mirror).
        loss_dB: Insertion loss in dB. Defaults to 0.0 dB.

    Returns:
        S-matrix dictionary for the mirror.

    Examples:
        Perfect mirror:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mirror(wl=wl)
        refl = np.abs(s[("o1", "o1")]) ** 2
        plt.figure()
        plt.plot(wl, refl, label="reflection")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    return reflector(wl=wl, reflection=reflection, loss_dB=loss_dB)

mmi1x2

mmi1x2(
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    fwhm: FloatArrayLike = 0.2,
    loss_dB: FloatArrayLike = 0.3,
) -> SDict

Realistic 1x2 MMI splitter model with dispersion and loss.

Port layout: input o1 (in0); outputs o2 (out1) and o3 (out0).

Parameters:

Name Type Description Default
wl FloatArrayLike

The wavelength in micrometers.

WL_C
wl0 FloatArrayLike

The center wavelength of the MMI in micrometers.

WL_C
fwhm FloatArrayLike

The full width at half maximum of the MMI in micrometers.

0.2
loss_dB FloatArrayLike

Insertion loss in dB at the center wavelength.

0.3

Returns:

Type Description
SDict

S-matrix dictionary representing the dispersive MMI splitter behavior.

Examples:

Basic 1x2 MMI:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi1x2(wl=wl)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi1x2(
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    fwhm: sax.FloatArrayLike = 0.2,
    loss_dB: sax.FloatArrayLike = 0.3,
) -> sax.SDict:
    r"""Realistic 1x2 MMI splitter model with dispersion and loss.

    ```{svgbob}
            +-------------+     out1
            |             |---* o2
     o1 *---|             |
    in0     |             |---* o3
            +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.
        wl0: The center wavelength of the MMI in micrometers.
        fwhm: The full width at half maximum of the MMI in micrometers.
        loss_dB: Insertion loss in dB at the center wavelength.

    Returns:
        S-matrix dictionary representing the dispersive MMI splitter behavior.

    Examples:
        Basic 1x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi1x2(wl=wl)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    thru = _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB) / 2**0.5

    p = sax.PortNamer(1, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o2): thru,
            (p.o1, p.o3): thru,
        }
    )

mmi1x2_ideal

mmi1x2_ideal(wl: FloatArrayLike = WL_C) -> SDict

Ideal 1x2 multimode interference (MMI) splitter model.

Port layout: input o1 (in0); outputs o2 (out1) and o3 (out0).

Parameters:

Name Type Description Default
wl FloatArrayLike

The wavelength in micrometers.

WL_C

Returns:

Type Description
SDict

S-matrix dictionary representing the ideal MMI splitter behavior.

Examples:

Ideal 1x2 MMI:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi1x2_ideal(wl=wl)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi1x2_ideal(wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal 1x2 multimode interference (MMI) splitter model.

    ```{svgbob}
            +-------------+     out1
            |             |---* o2
     o1 *---|             |
    in0     |             |---* o3
            +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.

    Returns:
        S-matrix dictionary representing the ideal MMI splitter behavior.

    Examples:
        Ideal 1x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi1x2_ideal(wl=wl)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    return splitter_ideal(wl=wl, coupling=0.5)

mmi2x2

mmi2x2(
    wl: FloatArrayLike = WL_C,
    wl0: FloatArrayLike = WL_C,
    fwhm: FloatArrayLike = 0.2,
    loss_dB: FloatArrayLike = 0.3,
    shift: FloatArrayLike = 0.0,
    loss_dB_cross: FloatArrayLike | None = None,
    loss_dB_thru: FloatArrayLike | None = None,
    splitting_ratio_cross: FloatArrayLike = 0.5,
    splitting_ratio_thru: FloatArrayLike = 0.5,
) -> SDict

Realistic 2x2 MMI coupler model with dispersion and asymmetry.

Port layout: inputs o1 (in0) and o2 (in1); outputs o3 (out1) and o4 (out0).

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C
wl0 FloatArrayLike

Center wavelength of the MMI in micrometers.

WL_C
fwhm FloatArrayLike

Full width at half maximum bandwidth in micrometers.

0.2
loss_dB FloatArrayLike

Insertion loss of the MMI at center wavelength, in dB.

0.3
shift FloatArrayLike

The peak shift of the cross-transmission in micrometers.

0.0
loss_dB_cross FloatArrayLike | None

Optional separate insertion loss in dB for cross ports. If None, uses loss_dB. Allows modeling of asymmetric loss.

None
loss_dB_thru FloatArrayLike | None

Optional separate insertion loss in dB for bar (through) ports. If None, uses loss_dB. Allows modeling of asymmetric loss.

None
splitting_ratio_cross FloatArrayLike

Power splitting ratio for cross ports (0 to 1). Allows modeling of imbalanced coupling. Defaults to 0.5.

0.5
splitting_ratio_thru FloatArrayLike

Power splitting ratio for bar ports (0 to 1). Allows modeling of imbalanced transmission. Defaults to 0.5.

0.5

Returns:

Type Description
SDict

S-matrix dictionary representing the realistic MMI coupler behavior.

Examples:

Basic 2x2 MMI:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi2x2(wl=wl, shift=0.001)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()

Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi2x2(
    wl: sax.FloatArrayLike = sax.WL_C,
    wl0: sax.FloatArrayLike = sax.WL_C,
    fwhm: sax.FloatArrayLike = 0.2,
    loss_dB: sax.FloatArrayLike = 0.3,
    shift: sax.FloatArrayLike = 0.0,
    loss_dB_cross: sax.FloatArrayLike | None = None,
    loss_dB_thru: sax.FloatArrayLike | None = None,
    splitting_ratio_cross: sax.FloatArrayLike = 0.5,
    splitting_ratio_thru: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Realistic 2x2 MMI coupler model with dispersion and asymmetry.

    ```{svgbob}
    in1     +-------------+     out1
     o2 *---|             |---* o3
            |             |
     o1 *---|             |---* o4
    in0     +-------------+     out0
    ```

    Args:
        wl: Wavelength in micrometers.
        wl0: Center wavelength of the MMI in micrometers.
        fwhm: Full width at half maximum bandwidth in micrometers.
        loss_dB: Insertion loss of the MMI at center wavelength, in dB.
        shift: The peak shift of the cross-transmission in micrometers.
        loss_dB_cross: Optional separate insertion loss in dB for cross ports.
            If None, uses loss_dB. Allows modeling of asymmetric loss.
        loss_dB_thru: Optional separate insertion loss in dB for bar (through)
            ports. If None, uses loss_dB. Allows modeling of asymmetric loss.
        splitting_ratio_cross: Power splitting ratio for cross ports (0 to 1).
            Allows modeling of imbalanced coupling. Defaults to 0.5.
        splitting_ratio_thru: Power splitting ratio for bar ports (0 to 1).
            Allows modeling of imbalanced transmission. Defaults to 0.5.

    Returns:
        S-matrix dictionary representing the realistic MMI coupler behavior.

    Examples:
        Basic 2x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi2x2(wl=wl, shift=0.001)
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
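    # Fall back to the shared loss only when no per-path value is given,
    # so that an explicit 0.0 dB is respected.
    loss_dB_cross = loss_dB if loss_dB_cross is None else loss_dB_cross
    loss_dB_thru = loss_dB if loss_dB_thru is None else loss_dB_thru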

    # Convert splitting ratios from power to amplitude by taking the square root
    amplitude_ratio_thru = splitting_ratio_thru**0.5
    amplitude_ratio_cross = splitting_ratio_cross**0.5

    # _mmi_amp already includes the loss, so we don't need to apply it again
    thru = (
        _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB_thru) * amplitude_ratio_thru
    )
    cross = (
        1j
        * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB_cross)
        * amplitude_ratio_cross
    )

    p = sax.PortNamer(2, 2)
    return sax.reciprocal(
        {
            (p.o1, p.o3): thru,
            (p.o1, p.o4): cross,
            (p.o2, p.o3): cross,
            (p.o2, p.o4): thru,
        }
    )
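
The power-to-amplitude conversion in the source above can be illustrated at the center wavelength alone. The sketch below ignores the spectral envelope (the `_mmi_amp` helper is not shown here) and assumes that at the center wavelength the envelope reduces to the loss factor `10 ** (-loss_dB / 20)`; `center_amplitudes` is a hypothetical name.

```python
import math


def center_amplitudes(
    loss_db: float = 0.3, ratio_thru: float = 0.5, ratio_cross: float = 0.5
) -> tuple[complex, complex]:
    """Thru and cross field amplitudes at the center wavelength (envelope ignored)."""
    loss_amp = 10 ** (-loss_db / 20)  # dB power loss to amplitude factor
    thru = loss_amp * math.sqrt(ratio_thru)  # power ratio to amplitude
    cross = 1j * loss_amp * math.sqrt(ratio_cross)  # cross path gets a 90 degree phase
    return thru, cross


thru, cross = center_amplitudes()
total_power = abs(thru) ** 2 + abs(cross) ** 2
print(total_power, 10 * math.log10(total_power))
```

With splitting ratios that sum to one, the total output power under these assumptions equals the insertion loss alone, here -0.3 dB.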

mmi2x2_ideal

mmi2x2_ideal(*, wl: FloatArrayLike = WL_C) -> SDict

Ideal 2x2 multimode interference (MMI) coupler model.

Port layout: inputs o1 (in0) and o2 (in1); outputs o3 (out1) and o4 (out0).

Parameters:

Name Type Description Default
wl FloatArrayLike

The wavelength in micrometers.

WL_C

Returns:

Type Description
SDict

S-matrix dictionary representing the ideal MMI coupler behavior.

Examples:

Ideal 2x2 MMI:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.mmi2x2_ideal(wl=wl)
thru = np.abs(s[("o1", "o4")]) ** 2
cross = np.abs(s[("o1", "o3")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/mmis.py
@jax.jit
@validate_call
def mmi2x2_ideal(*, wl: sax.FloatArrayLike = sax.WL_C) -> sax.SDict:
    """Ideal 2x2 multimode interference (MMI) coupler model.

    ```{svgbob}
    in1     +-------------+     out1
     o2 *---|             |---* o3
            |             |
     o1 *---|             |---* o4
    in0     +-------------+     out0
    ```

    Args:
        wl: The wavelength in micrometers.

    Returns:
        S-matrix dictionary representing the ideal MMI coupler behavior.

    Examples:
        Ideal 2x2 MMI:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.mmi2x2_ideal(wl=wl)
        thru = np.abs(s[("o1", "o4")]) ** 2
        cross = np.abs(s[("o1", "o3")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    return coupler_ideal(wl=wl, coupling=0.5)

model_2port

model_2port(p1: Name, p2: Name) -> SDictModel

Generate a general 2-port model with 100% transmission.

Port layout: p1 connected directly to p2.

Parameters:

Name Type Description Default
p1 Name

Name of the first port (typically in0 or o1).

required
p2 Name

Name of the second port (typically out0 or o2).

required

Returns:

Type Description
SDictModel

A 2-port model

Source code in src/sax/models/factories.py
@validate_call
def model_2port(p1: sax.Name, p2: sax.Name) -> sax.SDictModel:
    """Generate a general 2-port model with 100% transmission.

    ```{svgbob}
     p1 *--------* p2
    ```

    Args:
        p1: Name of the first port (typically in0 or o1).
        p2: Name of the second port (typically out0 or o2).

    Returns:
        A 2-port model
    """

    @jax.jit
    @validate_call
    def model_2port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        return sax.reciprocal({(p1, p2): jnp.ones_like(wl)})

    return model_2port

model_3port

model_3port(p1: Name, p2: Name, p3: Name) -> SDictModel

Generate a general 3-port model (1x2 splitter).

Port layout: input p1; outputs p2 and p3.

Parameters:

Name Type Description Default
p1 Name

Name of the input port (typically o1 or in0).

required
p2 Name

Name of the first output port (typically o2 or out1).

required
p3 Name

Name of the second output port (typically o3 or out0).

required

Returns:

Type Description
SDictModel

A 3-port model

Source code in src/sax/models/factories.py
@validate_call
def model_3port(p1: sax.Name, p2: sax.Name, p3: sax.Name) -> sax.SDictModel:
    """Generate a general 3-port model (1x2 splitter).

    ```{svgbob}
            +-----+--* p2
     p1 *---|     |
            +-----+--* p3
    ```

    Args:
        p1: Name of the input port (typically o1 or in0).
        p2: Name of the first output port (typically o2 or out1).
        p3: Name of the second output port (typically o3 or out0).

    Returns:
        A 3-port model
    """

    @jax.jit
    @validate_call
    def model_3port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        thru = jnp.ones_like(wl) / jnp.sqrt(2)
        return sax.reciprocal(
            {
                (p1, p2): thru,
                (p1, p3): thru,
            }
        )

    return model_3port

model_4port

model_4port(p1: Name, p2: Name, p3: Name, p4: Name) -> SDictModel

Generate a general 4-port model (2x2 coupler).

Port layout: inputs p1 and p2; outputs p3 and p4.

Parameters:

Name Type Description Default
p1 Name

Name of the first input port (typically o1 or in0).

required
p2 Name

Name of the second input port (typically o2 or in1).

required
p3 Name

Name of the first output port (typically o3 or out1).

required
p4 Name

Name of the second output port (typically o4 or out0).

required

Returns:

Type Description
SDictModel

A 4-port model

Source code in src/sax/models/factories.py
@validate_call
def model_4port(
    p1: sax.Name, p2: sax.Name, p3: sax.Name, p4: sax.Name
) -> sax.SDictModel:
    """Generate a general 4-port model (2x2 coupler).

    ```{svgbob}
    p2 *---+-----+--* p3
           |     |
    p1 *---+-----+--* p4
    ```

    Args:
        p1: Name of the first input port (typically o1 or in0).
        p2: Name of the second input port (typically o2 or in1).
        p3: Name of the first output port (typically o3 or out1).
        p4: Name of the second output port (typically o4 or out0).

    Returns:
        A 4-port model
    """

    @jax.jit
    @validate_call
    def model_4port(wl: sax.FloatArrayLike = 1.5) -> sax.SDict:
        wl = jnp.asarray(wl)
        thru = jnp.ones_like(wl) / jnp.sqrt(2)
        cross = 1j * thru
        return sax.reciprocal(
            {
                (p1, p4): thru,
                (p2, p3): thru,
                (p1, p3): cross,
                (p2, p4): cross,
            }
        )

    return model_4port

passthru

passthru(num_links: int, *, reciprocal: bool = True) -> SCooModel

Copy 100% of the power at each input port to its corresponding output port.

Diagram: input ports (o1/in0, o2/in1, and so on) stacked on the left, each linked to an output port on the right.

Parameters:

Name Type Description Default
num_links int

Number of independent pass-through links (input-output pairs). This creates a device with num_links inputs and num_links outputs.

required
reciprocal bool

If True, the device exhibits reciprocal behavior where transmission is identical in both directions. This is typical for passive devices like fibers and waveguides. Defaults to True.

True

Returns:

Type Description
SCooModel

A passthru model

Examples:

Create an 8×8 optical switch (straight-through state):

switch_thru = sax.models.passthru(8, reciprocal=True)
Si, Sj, Sx, port_map = switch_thru(wl=1.55)
# Each input passes straight to corresponding output
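
To make the returned sparse triple more concrete, here is a plain-Python sketch of the index and value lists for a reciprocal pass-through. It assumes inputs occupy indices `0..N-1` and outputs `N..2N-1`, as the port map in the `unitary` source suggests. The function name `passthru_coo_sketch` is hypothetical, and the entry ordering SAX actually produces is not asserted here.

```python
def passthru_coo_sketch(num_links):
    """Index/value triples of a reciprocal pass-through (plain-Python sketch)."""
    n = num_links
    pairs = [(i, n + i) for i in range(n)]    # input i -> output i
    pairs += [(n + i, i) for i in range(n)]   # mirrored entries (reciprocal)
    pairs.sort()
    si = [i for i, _ in pairs]
    sj = [j for _, j in pairs]
    sx = [1.0] * len(pairs)                   # 100% amplitude on every link
    return si, sj, sx


si, sj, sx = passthru_coo_sketch(3)
print(list(zip(si, sj)))
```
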
Source code in src/sax/models/factories.py
@validate_call
def passthru(
    num_links: int,
    *,
    reciprocal: bool = True,
) -> sax.SCooModel:
    """Copy 100% of the power at each input port to its corresponding output port.

    ```{svgbob}
    o_ni -1 *---+-----+--* o_ni
                |     |
                  .
                  .
                  .
                |     |
         o2 *---+-----+--* o_N -2
        in1     |     |
                |     |
         o1 *---+-----+--* o_N -1
        in0
    ```

    Args:
        num_links: Number of independent pass-through links (input-output pairs).
            This creates a device with num_links inputs and num_links outputs.
        reciprocal: If True, the device exhibits reciprocal behavior where
            transmission is identical in both directions. This is typical for
            passive devices like fibers and waveguides. Defaults to True.

    Returns:
        A passthru model

    Examples:
        Create an 8×8 optical switch (straight-through state):

        ```python
        switch_thru = sax.models.passthru(8, reciprocal=True)
        Si, Sj, Sx, port_map = switch_thru(wl=1.55)
        # Each input passes straight to corresponding output
        ```
    """
    passthru = unitary(
        num_links,
        num_links,
        reciprocal=reciprocal,
        diagonal=True,
    )
    passthru.__name__ = f"passthru_{num_links}_{num_links}"
    passthru.__qualname__ = f"passthru_{num_links}_{num_links}"
    return jax.jit(passthru)

phase_shifter

phase_shifter(
    wl: FloatArrayLike = WL_C,
    neff: FloatArrayLike = 2.34,
    voltage: FloatArrayLike = 0,
    length: FloatArrayLike = 10,
    loss: FloatArrayLike = 0.0,
) -> SDict

Simple voltage-controlled phase shifter model.

Diagram: a single waveguide running from in0 (o1) to out0 (o2).

Parameters:

Name Type Description Default
wl FloatArrayLike

The wavelength in micrometers.

WL_C
neff FloatArrayLike

Effective index of the unperturbed waveguide mode.

2.34
voltage FloatArrayLike

Applied voltage in volts. The phase shift is assumed to be linearly proportional to voltage with a coefficient of π rad/V. Positive voltage increases the phase. Defaults to 0 V.

0
length FloatArrayLike

The length of the phase shifter in micrometers.

10
loss FloatArrayLike

Propagation loss in dB/μm in the active region; the total loss is loss × length dB.

0.0

Returns:

Type Description
SDict

S-matrix dictionary containing the complex-valued transmission coefficient.

Examples:

Phase shifter:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.phase_shifter(wl=wl, loss=3.0)
thru = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
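
The transmission formula in the source listing can also be checked by hand. The scalar sketch below restates it with the standard library only; the default wavelength of 1.55 μm is an assumption standing in for `WL_C`, and `phase_shifter_sketch` is a hypothetical name. Note that `loss` is multiplied by `length` in the formula, so it acts as a loss per micrometer.

```python
import cmath
import math


def phase_shifter_sketch(wl=1.55, neff=2.34, voltage=0.0, length=10.0, loss=0.0):
    """Scalar re-statement of the transmission formula in the source listing."""
    phase = 2 * math.pi * neff * length / wl + voltage * math.pi
    amplitude = 10 ** (-loss * length / 20)
    return amplitude * cmath.exp(1j * phase)


t0 = phase_shifter_sketch(voltage=0.0)
t1 = phase_shifter_sketch(voltage=1.0)    # 1 V adds a phase of pi
t_lossy = phase_shifter_sketch(loss=0.1)  # 0.1 dB/um * 10 um = 1 dB in total

print(abs(t1 + t0))        # close to 0: the sign has flipped
print(abs(t_lossy) ** 2)   # close to 10 ** -0.1
```
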
Source code in src/sax/models/straight.py
@jax.jit
@validate_call
def phase_shifter(
    wl: sax.FloatArrayLike = sax.WL_C,
    neff: sax.FloatArrayLike = 2.34,
    voltage: sax.FloatArrayLike = 0,
    length: sax.FloatArrayLike = 10,
    loss: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    """Simple voltage-controlled phase shifter model.

    ```{svgbob}
    in0             out0
     o1 =========== o2
    ```

    Args:
        wl: The wavelength in micrometers.
        neff: Effective index of the unperturbed waveguide mode.
        voltage: Applied voltage in volts. The phase shift is assumed to be
            linearly proportional to voltage with a coefficient of π rad/V.
            Positive voltage increases the phase. Defaults to 0 V.
        length: The length of the phase shifter in micrometers.
        loss: Propagation loss in dB/μm in the active region; the total loss is
            ``loss * length`` dB.

    Returns:
        S-matrix dictionary containing the complex-valued transmission coefficient.

    Examples:
        Phase shifter:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.phase_shifter(wl=wl, loss=3.0)
        thru = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    deltaphi = voltage * jnp.pi
    phase = 2 * jnp.pi * neff * length / wl + deltaphi
    amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    transmission = amplitude * jnp.exp(1j * phase)
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.o1, p.o2): transmission,
        }
    )

propagation_constant

propagation_constant(
    f: FloatArrayLike,
    ep_eff: FloatArrayLike,
    tand: FloatArrayLike = 0.0,
    ep_r: FloatArrayLike = 1.0,
) -> Array

Complex propagation constant of a quasi-TEM transmission line.

For the general lossy case

\[ \gamma = \alpha_d + j\,\beta \]

where the dielectric attenuation is

\[ \alpha_d = \frac{\pi f}{c} \frac{\varepsilon_r}{\sqrt{\varepsilon_{\mathrm{eff}}}} \frac{\varepsilon_{\mathrm{eff}} - 1}{\varepsilon_r - 1} \tan\delta \]

and the phase constant is

\[ \beta = \frac{2\pi f}{c}\,\sqrt{\varepsilon_{\mathrm{eff}}} \]

Here \(c\) is the speed of light in vacuum in m/s (C_M_S in the source code).

For a superconducting line (\(\tan\delta = 0\)) the propagation constant is purely imaginary: \(\gamma = j\beta\).

References

Pozar, §3.8

Parameters:

Name Type Description Default
f FloatArrayLike

Frequency (Hz).

required
ep_eff FloatArrayLike

Effective permittivity.

required
tand FloatArrayLike

Dielectric loss tangent (default 0 — lossless).

0.0
ep_r FloatArrayLike

Substrate relative permittivity (only needed when tand > 0).

1.0

Returns:

Type Description
Array

Complex propagation constant \(\gamma\) (1/m).
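
A scalar sketch of the same formulas, using only the standard library, shows how the two terms behave. The value of `C_M_S` is assumed to be the SI speed of light, the numeric inputs are hypothetical example values, and `propagation_constant_sketch` is not part of SAX.

```python
import math

C_M_S = 299_792_458.0  # assumed: speed of light in vacuum (m/s)


def propagation_constant_sketch(f, ep_eff, tand=0.0, ep_r=1.0):
    """Scalar version of gamma = alpha_d + j*beta for a quasi-TEM line."""
    beta = 2 * math.pi * f * math.sqrt(ep_eff) / C_M_S
    if abs(ep_r - 1.0) < 1e-15:
        alpha_d = 0.0  # guard against division by zero, as in the source listing
    else:
        alpha_d = (
            math.pi * f / C_M_S
            * (ep_r / math.sqrt(ep_eff))
            * ((ep_eff - 1.0) / (ep_r - 1.0))
            * tand
        )
    return complex(alpha_d, beta)


# Hypothetical example values: 5 GHz, eps_eff = 6.0, eps_r = 11.0
gamma_lossless = propagation_constant_sketch(5e9, 6.0)
gamma_lossy = propagation_constant_sketch(5e9, 6.0, tand=1e-3, ep_r=11.0)

guided_wavelength = 2 * math.pi / gamma_lossless.imag  # metres
print(gamma_lossless.real, gamma_lossy.real)
print(guided_wavelength)
```

The loss tangent only affects the real part; the phase constant, and therefore the guided wavelength, is the same in both calls.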

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def propagation_constant(
    f: sax.FloatArrayLike,
    ep_eff: sax.FloatArrayLike,
    tand: sax.FloatArrayLike = 0.0,
    ep_r: sax.FloatArrayLike = 1.0,
) -> jax.Array:
    r"""Complex propagation constant of a quasi-TEM transmission line.

    For the general lossy case

    $$
    \gamma = \alpha_d + j\,\beta
    $$




    where the **dielectric attenuation** is



    $$
    \alpha_d = \frac{\pi f}{c}
    \frac{\varepsilon_r}{\sqrt{\varepsilon_{\mathrm{eff}}}}
    \frac{\varepsilon_{\mathrm{eff}} - 1}
    {\varepsilon_r - 1}
    \tan\delta
    $$

    and the **phase constant** is

    $$
    \beta = \frac{2\pi f}{c}\,\sqrt{\varepsilon_{\mathrm{eff}}}
    $$

    Here $c$ is the speed of light in vacuum in m/s (``C_M_S`` in the code).

    For a superconducting line ($\tan\delta = 0$) the propagation
    constant is purely imaginary: $\gamma = j\beta$.

    References:
        Pozar, §3.8

    Args:
        f: Frequency (Hz).
        ep_eff: Effective permittivity.
        tand: Dielectric loss tangent (default 0 — lossless).
        ep_r: Substrate relative permittivity (only needed when ``tand > 0``).

    Returns:
        Complex propagation constant $\gamma$ (1/m).
    """
    f = jnp.asarray(f, dtype=float)
    ep_eff = jnp.asarray(ep_eff, dtype=float)
    tand = jnp.asarray(tand, dtype=float)
    ep_r = jnp.asarray(ep_r, dtype=float)

    beta = 2.0 * jnp.pi * f * jnp.sqrt(ep_eff) / C_M_S

    denom = jnp.where(jnp.abs(ep_r - 1.0) < 1e-15, 1.0, ep_r - 1.0)
    alpha_d = (
        jnp.pi * f / C_M_S * (ep_r / jnp.sqrt(ep_eff)) * ((ep_eff - 1.0) / denom) * tand
    )
    alpha_d = jnp.where(jnp.abs(ep_r - 1.0) < 1e-15, 0.0, alpha_d)

    return alpha_d + 1j * beta

reflector

reflector(
    *,
    wl: FloatArrayLike = WL_C,
    reflection: FloatArrayLike = 0.5,
    loss_dB: FloatArrayLike = 0.0,
) -> SDict

Partial reflector / mirror model.

A 2-port component that reflects a fraction of the input power and transmits the rest (minus any loss).

Diagram: in0 (o1) and out0 (o2) on either side of a partially reflecting mirror.

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C
reflection FloatArrayLike

Power reflection coefficient between 0 and 1. 0 means full transmission, 1 means full reflection. Defaults to 0.5.

0.5
loss_dB FloatArrayLike

Insertion loss in dB applied to both reflected and transmitted amplitudes. Defaults to 0.0 dB.

0.0

Returns:

Type Description
SDict

S-matrix dictionary for the reflector.

Examples:

50% reflector:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.reflector(wl=wl, reflection=0.5)
thru = np.abs(s[("o1", "o2")]) ** 2
refl = np.abs(s[("o1", "o1")]) ** 2
plt.figure()
plt.plot(wl, thru, label="transmission")
plt.plot(wl, refl, label="reflection")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
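
Because `loss_dB` scales both amplitudes, the reflected and transmitted powers always add up to the power left after the insertion loss. The scalar sketch below (standard library only, with the hypothetical helper name `reflector_sketch`) follows the amplitude expressions in the source listing.

```python
import math


def reflector_sketch(reflection=0.5, loss_dB=0.0):
    """Scalar amplitudes (r, t) following the source listing."""
    loss_amp = 10 ** (-loss_dB / 20)
    r = math.sqrt(reflection) * loss_amp
    t = math.sqrt(1 - reflection) * loss_amp
    return r, t


r, t = reflector_sketch(reflection=0.3, loss_dB=1.0)
total_power = r**2 + t**2
print(total_power)           # the power remaining after 1 dB of loss
print(r**2 / total_power)    # the reflected share of what remains
```
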
Source code in src/sax/models/reflectors.py
@jax.jit
@validate_call
def reflector(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    reflection: sax.FloatArrayLike = 0.5,
    loss_dB: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    r"""Partial reflector / mirror model.

    A 2-port component that reflects a fraction of the input power and
    transmits the rest (minus any loss).

    ```{svgbob}
    in0          |   out0
     o1 ========= | ========= o2
                 |
              mirror
    ```

    Args:
        wl: Wavelength in micrometers.
        reflection: Power reflection coefficient between 0 and 1.
            0 means full transmission, 1 means full reflection. Defaults to 0.5.
        loss_dB: Insertion loss in dB applied to both reflected and
            transmitted amplitudes. Defaults to 0.0 dB.

    Returns:
        S-matrix dictionary for the reflector.

    Examples:
        50% reflector:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.reflector(wl=wl, reflection=0.5)
        thru = np.abs(s[("o1", "o2")]) ** 2
        refl = np.abs(s[("o1", "o1")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="transmission")
        plt.plot(wl, refl, label="reflection")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    loss_amp = jnp.asarray(10 ** (-loss_dB / 20), dtype=complex) * one
    r = jnp.asarray(reflection**0.5, dtype=complex) * loss_amp
    t = jnp.asarray((1 - reflection) ** 0.5, dtype=complex) * loss_amp
    p = sax.PortNamer(1, 1)
    return sax.reciprocal(
        {
            (p.in0, p.in0): r,
            (p.in0, p.out0): t,
            (p.out0, p.out0): r,
        }
    )

resistor

resistor(
    *,
    f: FloatArrayLike = DEFAULT_FREQUENCY,
    resistance: FloatLike = 50,
    z0: ComplexLike = 50,
) -> SDict

Ideal two-port resistor model.

Parameters:

Name Type Description Default
f FloatArrayLike

Frequency in Hz

DEFAULT_FREQUENCY
resistance FloatLike

Resistance in Ohms

50
z0 ComplexLike

Reference impedance in Ω.

50

Returns:

Type Description
SDict

S-dictionary representing the resistor element

References

[@pozar2012]

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
s = sax.models.rf.resistor(f=f, resistance=100, z0=50)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Magnitude")
plt.legend()
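
For a quick sanity check of the plotted values, the sketch below uses the textbook S-parameters of a series impedance between two ports of reference impedance `z0`. This assumes that `impedance` models the element as a series connection; its source is not shown in this section, so treat the formulas as an assumption rather than as the SAX implementation. `series_impedance_s_params` is a hypothetical helper.

```python
def series_impedance_s_params(z, z0=50.0):
    """Textbook S-parameters of a series impedance z between two z0 ports."""
    s11 = z / (z + 2 * z0)
    s21 = 2 * z0 / (z + 2 * z0)
    return s11, s21


s11, s21 = series_impedance_s_params(100.0, 50.0)
absorbed = 1 - abs(s11) ** 2 - abs(s21) ** 2  # power dissipated in the resistor
print(s11, s21, absorbed)
```

Under that assumption, a 100 Ω resistor in a 50 Ω system reflects and transmits equal amplitudes, and the remaining power is dissipated in the resistor.
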
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def resistor(
    *,
    f: sax.FloatArrayLike = DEFAULT_FREQUENCY,
    resistance: sax.FloatLike = 50,
    z0: sax.ComplexLike = 50,
) -> sax.SDict:
    r"""Ideal two-port resistor model.

    Args:
        f: Frequency in Hz
        resistance: Resistance in Ohms
        z0: Reference impedance in Ω.

    Returns:
        S-dictionary representing the resistor element

    References:
        [@pozar2012]

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        s = sax.models.rf.resistor(f=f, resistance=100, z0=50)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o1")]), label="|S11|")
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]), label="|S12|")
        plt.plot(f / 1e9, np.abs(s[("o2", "o2")]), label="|S22|")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Magnitude")
        plt.legend()
        ```
    """
    return impedance(f=f, z=resistance, z0=z0)

splitter_ideal

splitter_ideal(*, wl: FloatArrayLike = WL_C, coupling: FloatArrayLike = 0.5) -> SDict

Ideal 1x2 power splitter model.

Diagram: input in0 (o1) on the left, splitting into out0 (o2) and out1 (o3) on the right.

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C
coupling FloatArrayLike

Fraction of the input power (between 0 and 1) coupled to out1; the remainder goes to out0.

0.5

Returns:

Type Description
SDict

S-matrix dictionary containing the complex-valued cross/thru coefficients.

Examples:

Ideal 1x2 splitter:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.splitter_ideal(wl=wl, coupling=0.3)
thru = np.abs(s[("o1", "o3")]) ** 2
cross = np.abs(s[("o1", "o2")]) ** 2
plt.figure()
plt.plot(wl, thru, label="thru")
plt.plot(wl, cross, label="cross")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/splitters.py
@jax.jit
@validate_call
def splitter_ideal(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    coupling: sax.FloatArrayLike = 0.5,
) -> sax.SDict:
    r"""Ideal 1x2 power splitter model.

    ```{svgbob}
                    .--------- out0
                   /           o2
                  /   .-------
    in0 ---------'   /
     o1             (
        ---------.   \
                  \   '------- out1
                   \           o3
                    '---------
    ```

    Args:
        wl: Wavelength in micrometers.
        coupling: Fraction of the input power (between 0 and 1) coupled to
            out1; the remainder goes to out0.

    Returns:
        S-matrix dictionary containing the complex-valued cross/thru coefficients.

    Examples:
        Ideal 1x2 splitter:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.splitter_ideal(wl=wl, coupling=0.3)
        thru = np.abs(s[("o1", "o3")]) ** 2
        cross = np.abs(s[("o1", "o2")]) ** 2
        plt.figure()
        plt.plot(wl, thru, label="thru")
        plt.plot(wl, cross, label="cross")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    kappa = jnp.asarray(coupling**0.5) * one
    tau = jnp.asarray((1 - coupling) ** 0.5) * one

    p = sax.PortNamer(num_inputs=1, num_outputs=2)
    return sax.reciprocal(
        {
            (p.in0, p.out0): tau,
            (p.in0, p.out1): kappa,
        },
    )

tee

tee(*, f: FloatArrayLike = DEFAULT_FREQUENCY) -> SDict

Ideal three-port RF power divider/combiner (T-junction).

Diagram: T-junction with o1 and o3 on the horizontal arm and o2 on the vertical branch.

Parameters:

Name Type Description Default
f FloatArrayLike

Array of frequency points in Hz

DEFAULT_FREQUENCY

Returns:

Type Description
SDict

S-dictionary representing ideal RF T-junction behavior

Examples:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

f = np.linspace(1e9, 10e9, 500)
s = sax.models.rf.tee(f=f)
plt.figure()
plt.plot(f / 1e9, np.abs(s[("o1", "o2")]) ** 2, label="|S12|^2")
plt.plot(f / 1e9, np.abs(s[("o1", "o3")]) ** 2, label="|S13|^2")
plt.plot(f / 1e9, np.abs(s[("o2", "o3")]) ** 2, label="|S23|^2")
plt.xlabel("Frequency [GHz]")
plt.ylabel("Power")
plt.legend()
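
The constant values in the source listing (-1/3 on the diagonal, 2/3 elsewhere) conserve power exactly. The short check below uses exact fractions; the variable names are chosen for this example.

```python
from fractions import Fraction

s_reflect = Fraction(-1, 3)   # S_ii in the source listing
s_transmit = Fraction(2, 3)   # S_ij for i != j

reflected = s_reflect ** 2            # power returned to the driven port
transmitted_each = s_transmit ** 2    # power delivered to each other port
total = reflected + 2 * transmitted_each
print(reflected, transmitted_each, total)
```

So a wave entering any port has one ninth of its power reflected and four ninths delivered to each of the other two ports.
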
Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def tee(*, f: sax.FloatArrayLike = DEFAULT_FREQUENCY) -> sax.SDict:
    """Ideal three-port RF power divider/combiner (T-junction).

    ```{svgbob}
            o2
            *
            |
            |
     o1 *---+---* o3
    ```

    Args:
        f: Array of frequency points in Hz

    Returns:
        S-dictionary representing ideal RF T-junction behavior

    Examples:
        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        f = np.linspace(1e9, 10e9, 500)
        s = sax.models.rf.tee(f=f)
        plt.figure()
        plt.plot(f / 1e9, np.abs(s[("o1", "o2")]) ** 2, label="|S12|^2")
        plt.plot(f / 1e9, np.abs(s[("o1", "o3")]) ** 2, label="|S13|^2")
        plt.plot(f / 1e9, np.abs(s[("o2", "o3")]) ** 2, label="|S23|^2")
        plt.xlabel("Frequency [GHz]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    f = jnp.asarray(f)
    f_flat = f.ravel()
    sdict = {(f"o{i}", f"o{i}"): jnp.full(f_flat.shape[0], -1 / 3) for i in range(1, 4)}
    sdict |= {
        (f"o{i}", f"o{j}"): jnp.full(f_flat.shape[0], 2 / 3)
        for i in range(1, 4)
        for j in range(i + 1, 4)
    }
    return sax.reciprocal({k: v.reshape(*f.shape) for k, v in sdict.items()})

terminator

terminator(*, wl: FloatArrayLike = WL_C, reflection: FloatArrayLike = 0.0) -> SDict

Optical terminator / absorber model.

A 1-port device that absorbs all incoming light with minimal reflection. Used to terminate unused ports and prevent back-reflections.

Diagram: a single port in0 (o1) ending in an absorber (X).

Parameters:

Name Type Description Default
wl FloatArrayLike

Wavelength in micrometers.

WL_C
reflection FloatArrayLike

Residual power reflection coefficient between 0 and 1. Defaults to 0.0 (perfect absorber).

0.0

Returns:

Type Description
SDict

S-matrix dictionary for the terminator.

Examples:

Terminator with 1% residual reflection:

# mkdocs: render
import matplotlib.pyplot as plt
import numpy as np
import sax

sax.set_port_naming_strategy("optical")

wl = sax.wl_c()
s = sax.models.terminator(wl=wl, reflection=0.01)
refl = np.abs(s[("o1", "o1")]) ** 2
plt.figure()
plt.plot(wl, refl, label="reflection")
plt.xlabel("Wavelength [μm]")
plt.ylabel("Power")
plt.legend()
Source code in src/sax/models/terminators.py
@jax.jit
@validate_call
def terminator(
    *,
    wl: sax.FloatArrayLike = sax.WL_C,
    reflection: sax.FloatArrayLike = 0.0,
) -> sax.SDict:
    """Optical terminator / absorber model.

    A 1-port device that absorbs all incoming light with minimal reflection.
    Used to terminate unused ports and prevent back-reflections.

    ```{svgbob}
    in0
     o1 =========X
    ```

    Args:
        wl: Wavelength in micrometers.
        reflection: Residual power reflection coefficient between 0 and 1.
            Defaults to 0.0 (perfect absorber).

    Returns:
        S-matrix dictionary for the terminator.

    Examples:
        Terminator with 1% residual reflection:

        ```python
        # mkdocs: render
        import matplotlib.pyplot as plt
        import numpy as np
        import sax

        sax.set_port_naming_strategy("optical")

        wl = sax.wl_c()
        s = sax.models.terminator(wl=wl, reflection=0.01)
        refl = np.abs(s[("o1", "o1")]) ** 2
        plt.figure()
        plt.plot(wl, refl, label="reflection")
        plt.xlabel("Wavelength [μm]")
        plt.ylabel("Power")
        plt.legend()
        ```
    """
    one = jnp.ones_like(jnp.asarray(wl))
    r = jnp.asarray(reflection**0.5, dtype=complex) * one
    return {
        ("o1", "o1"): r,
    }

transmission_line_s_params

transmission_line_s_params(
    gamma: ComplexLike,
    z0: ComplexLike,
    length: FloatArrayLike,
    z_ref: ComplexLike | None = None,
) -> tuple[Array, Array]

S-parameters of a uniform transmission line (ABCD→S conversion).

The ABCD matrix of a line with characteristic impedance \(Z_0\), propagation constant \(\gamma\), and length \(\ell\) is:

\[ \begin{aligned} \begin{pmatrix} A & B \\ C & D \end{pmatrix} = \begin{pmatrix} \cosh\theta & Z_0\sinh\theta \\ \sinh\theta / Z_0 & \cosh\theta \end{pmatrix}, \quad \theta = \gamma\ell \end{aligned} \]

Converting to S-parameters referenced to \(Z_{\mathrm{ref}}\)

\[ \begin{aligned} S_{11} &= \frac{A + B/Z_{\mathrm{ref}} - CZ_{\mathrm{ref}} - D}{ A + B/Z_{\mathrm{ref}} + CZ_{\mathrm{ref}} + D} \\ S_{21} &= \frac{2}{A + B/Z_{\mathrm{ref}} + CZ_{\mathrm{ref}} + D} \end{aligned} \]

When z_ref is None the reference impedance defaults to z0 (matched case), giving \(S_{11} = 0\) and \(S_{21} = e^{-\gamma\ell}\).

References

Pozar, Table 4.2

Parameters:

Name Type Description Default
gamma ComplexLike

Complex propagation constant (1/m).

required
z0 ComplexLike

Characteristic impedance (Ω).

required
length FloatArrayLike

Physical length (m).

required
z_ref ComplexLike | None

Reference (port) impedance (Ω). Defaults to z0.

None

Returns:

Type Description
tuple[Array, Array]

(S11, S21) — complex S-parameter arrays.
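
The matched-case statement can be verified numerically. The sketch below is a scalar re-statement of the ABCD-to-S conversion using only `cmath`; the propagation constant, line length, and impedances are hypothetical example values, and `tline_s_params_sketch` is not part of SAX.

```python
import cmath


def tline_s_params_sketch(gamma, z0, length, z_ref=None):
    """Scalar ABCD -> S conversion for a uniform line (A == D)."""
    if z_ref is None:
        z_ref = z0  # matched reference
    theta = gamma * length
    a = cmath.cosh(theta)
    b = z0 * cmath.sinh(theta)
    c = cmath.sinh(theta) / z0
    denom = 2 * a + b / z_ref + c * z_ref
    return (b / z_ref - c * z_ref) / denom, 2 / denom


gamma = 0.5 + 120j   # hypothetical propagation constant (1/m)
length = 5e-3        # hypothetical 5 mm line

s11_matched, s21_matched = tline_s_params_sketch(gamma, 50.0, length)
s11_mismatch, s21_mismatch = tline_s_params_sketch(gamma, 70.0, length, z_ref=50.0)

print(abs(s11_matched))                                # close to 0
print(abs(s21_matched - cmath.exp(-gamma * length)))   # close to 0
print(abs(s11_mismatch))                               # clearly non-zero
```

With a 70 Ω line measured against 50 Ω ports, the `B` and `C` terms no longer cancel in the numerator of `S11`, so a reflection appears.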

Source code in src/sax/models/rf.py
@partial(jax.jit, inline=True)
def transmission_line_s_params(
    gamma: sax.ComplexLike,
    z0: sax.ComplexLike,
    length: sax.FloatArrayLike,
    z_ref: sax.ComplexLike | None = None,
) -> tuple[jax.Array, jax.Array]:
    r"""S-parameters of a uniform transmission line (ABCD→S conversion).

    The ABCD matrix of a line with characteristic impedance $Z_0$,
    propagation constant $\gamma$, and length $\ell$ is:



    $$
    \begin{aligned}
        \begin{pmatrix} A & B \\ C & D \end{pmatrix}
    = \begin{pmatrix}
        \cosh\theta & Z_0\sinh\theta \\
        \sinh\theta / Z_0 & \cosh\theta
    \end{pmatrix}, \quad \theta = \gamma\ell
    \end{aligned}
    $$



    Converting to S-parameters referenced to $Z_{\mathrm{ref}}$



    $$
    \begin{aligned}
        S_{11} &= \frac{A + B/Z_{\mathrm{ref}} - CZ_{\mathrm{ref}} - D}{
            A + B/Z_{\mathrm{ref}} + CZ_{\mathrm{ref}} + D} \\
        S_{21} &= \frac{2}{A + B/Z_{\mathrm{ref}} + CZ_{\mathrm{ref}} + D}
    \end{aligned}
    $$




    When ``z_ref`` is ``None`` the reference impedance defaults to ``z0``
    (matched case), giving $S_{11} = 0$ and
    $S_{21} = e^{-\gamma\ell}$.

    References:
        Pozar, Table 4.2

    Args:
        gamma: Complex propagation constant (1/m).
        z0: Characteristic impedance (Ω).
        length: Physical length (m).
        z_ref: Reference (port) impedance (Ω).  Defaults to ``z0``.

    Returns:
        ``(S11, S21)`` — complex S-parameter arrays.
    """
    gamma_arr = jnp.asarray(gamma, dtype=complex)
    z0_arr = jnp.asarray(z0, dtype=complex)
    length_arr = jnp.asarray(length, dtype=float)

    if z_ref is None:
        z_ref = z0
    z_ref_arr = jnp.asarray(z_ref, dtype=complex)

    theta = gamma_arr * length_arr

    cosh_t = jnp.cosh(theta)
    sinh_t = jnp.sinh(theta)

    a = cosh_t
    b = z0_arr * sinh_t
    c = sinh_t / z0_arr

    denom = a + b / z_ref_arr + c * z_ref_arr + a
    s11 = (b / z_ref_arr - c * z_ref_arr) / denom
    s21 = 2.0 / denom

    return s11, s21

unitary

unitary(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> SCooModel

Generate a unitary N×M optical device model.

Diagram: input ports (o1/in0, o2/in1, and so on) stacked on the left, output ports on the right.

Parameters:

Name Type Description Default
num_inputs int

Number of input ports for the device.

required
num_outputs int

Number of output ports for the device.

required
reciprocal bool

If True, the device exhibits reciprocal behavior (S = S^T). This is typical for passive optical devices. Defaults to True.

True
diagonal bool

If True, creates a diagonal coupling matrix (each input couples to only one output). If False, creates full coupling between all input-output pairs. Defaults to False.

False

Returns:

Type Description
SCooModel

A unitary model
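
In the square, fully coupled case (`num_inputs == num_outputs == N`, `diagonal=False`), the SVD normalization in the source listing works out to an amplitude of 1/√N on every input-output path, so each input's power is shared equally among the outputs. The sketch below hard-codes that result instead of running the SVD; `equal_split_amplitudes` is a hypothetical helper, not part of SAX.

```python
import math


def equal_split_amplitudes(n):
    """Amplitude block (inputs x outputs) for a fully coupled n x n device."""
    amplitude = 1 / math.sqrt(n)  # power 1/n on every path
    return [[amplitude] * n for _ in range(n)]


block = equal_split_amplitudes(4)
power_per_input = [sum(a**2 for a in row) for row in block]
print(block[0][0])        # amplitude on a single path
print(power_per_input)    # total output power for each input, close to 1.0
```
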

Source code in src/sax/models/factories.py
@validate_call
def unitary(
    num_inputs: int,
    num_outputs: int,
    *,
    reciprocal: bool = True,
    diagonal: bool = False,
) -> sax.SCooModel:
    """Generate a unitary N×M optical device model.

    ```{svgbob}
    o_ni -1 *---+-----+--* o_ni
                |     |
                  .
                  .
                  .
                |     |
         o2 *---+-----+--* o_N -2
        in1     |     |
                |     |
         o1 *---+-----+--* o_N -1
        in0
    ```

    Args:
        num_inputs: Number of input ports for the device.
        num_outputs: Number of output ports for the device.
        reciprocal: If True, the device exhibits reciprocal behavior (S = S^T).
            This is typical for passive optical devices. Defaults to True.
        diagonal: If True, creates a diagonal coupling matrix (each input
            couples to only one output). If False, creates full coupling
            between all input-output pairs. Defaults to False.

    Returns:
        A unitary model
    """
    # let's create the squared S-matrix:
    N = max(num_inputs, num_outputs)
    S = jnp.zeros((2 * N, 2 * N), dtype=float)

    if not diagonal:
        S = S.at[:N, N:].set(1)
    else:
        r = jnp.arange(
            N,
            dtype=int,
        )  # reciprocal only works if num_inputs == num_outputs!
        S = S.at[r, N + r].set(1)

    if reciprocal:
        if not diagonal:
            S = S.at[N:, :N].set(1)
        else:
            r = jnp.arange(
                N,
                dtype=int,
            )  # reciprocal only works if num_inputs == num_outputs!
            S = S.at[N + r, r].set(1)

    # Now we need to normalize the squared S-matrix
    U, s, V = jnp.linalg.svd(S, full_matrices=False)
    S = jnp.sqrt(U @ jnp.diag(jnp.where(s > sax.EPS, 1, 0)) @ V)

    # Now create subset of this matrix we're interested in:
    r = jnp.concatenate(
        [jnp.arange(num_inputs, dtype=int), N + jnp.arange(num_outputs, dtype=int)],
        0,
    )
    S = S[r, :][:, r]
    S = S.T  # convert from (from, to) to standard (to, from) convention

    # let's convert it in SCOO format:
    Si, Sj = jnp.where(S > sax.EPS)
    Sx = S[Si, Sj]

    # the last missing piece is a port map:
    p = sax.PortNamer(num_inputs, num_outputs)
    pm = {
        **{p[i]: i for i in range(num_inputs)},
        **{p[i + num_inputs]: i + num_inputs for i in range(num_outputs)},
    }

    @validate_call
    def func(*, wl: sax.FloatArrayLike = 1.5) -> sax.SCoo:
        wl_ = jnp.asarray(wl)
        Sx_ = jnp.broadcast_to(Sx, (*wl_.shape, *Sx.shape))
        return Si, Sj, Sx_, pm

    func.__name__ = f"unitary_{num_inputs}_{num_outputs}"
    func.__qualname__ = f"unitary_{num_inputs}_{num_outputs}"
    return jax.jit(func)