Skip to content

Multimode simulations

SAX can handle multiple modes too!

from itertools import combinations_with_replacement, product

import jax.numpy as jnp

import sax

Ports and modes per port

Let's denote a combination of a port and a mode by a string of the format "{port}@{mode}". We can obtain all possible (unordered) pairs of port-mode combinations with `itertools.product` and `itertools.combinations_with_replacement`:

ports = ["in0", "out0"]
modes = ["TE", "TM"]
# Label every (port, mode) combination as "{port}@{mode}" first, then take all
# unordered pairs of labels (a label may be paired with itself).
portmodes = list(
    combinations_with_replacement(
        (f"{port}@{mode}" for port, mode in product(ports, modes)), 2
    )
)
portmodes
[('in0@TE', 'in0@TE'),
 ('in0@TE', 'in0@TM'),
 ('in0@TE', 'out0@TE'),
 ('in0@TE', 'out0@TM'),
 ('in0@TM', 'in0@TM'),
 ('in0@TM', 'out0@TE'),
 ('in0@TM', 'out0@TM'),
 ('out0@TE', 'out0@TE'),
 ('out0@TE', 'out0@TM'),
 ('out0@TM', 'out0@TM')]

If we disregard backreflection (i.e. drop every pair whose two entries belong to the same port), this can be further simplified:

# Keep only pairs that connect two different ports (the part before "@").
portmodes_without_backreflection = [
    (a, b)
    for a, b in portmodes
    if a.partition("@")[0] != b.partition("@")[0]
]
portmodes_without_backreflection
[('in0@TE', 'out0@TE'),
 ('in0@TE', 'out0@TM'),
 ('in0@TM', 'out0@TE'),
 ('in0@TM', 'out0@TM')]

Sometimes cross-polarization terms can also be ignored (i.e. keep only pairs whose two entries carry the same mode):

# Keep only pairs whose two labels share the same mode (the part after "@").
portmodes_without_crosspolarization = [
    (p1, p2)
    for p1, p2 in portmodes
    if len({p1.split("@")[1], p2.split("@")[1]}) == 1
]
portmodes_without_crosspolarization
[('in0@TE', 'in0@TE'),
 ('in0@TE', 'out0@TE'),
 ('in0@TM', 'in0@TM'),
 ('in0@TM', 'out0@TM'),
 ('out0@TE', 'out0@TE'),
 ('out0@TM', 'out0@TM')]

Multimode waveguide

Let's create a waveguide with two ports ("in0", "out0") and two modes ("TE", "TM") without backreflection. Let's assume there is 5% cross-polarization and that the "TM"→"TM" transmission is 10% worse than the "TE"→"TE" transmission (these factors scale the field amplitude). Naturally, in more realistic waveguide models these percentages will be length-dependent, but this is just a dummy model serving as an example.

def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
    """A simple two-mode (TE/TM) straight waveguide model without backreflection.

    The effective index is dispersed linearly around ``wl0`` using the group
    index. Fixed amplitude factors model the polarization behaviour: 0.95 for
    TE->TE, 0.85 for TM->TM and 0.05 for each cross-polarization term
    (TE->TM and TM->TE).

    Args:
        wl: wavelength
        wl0: center wavelength at which neff is defined
        neff: waveguide effective index at wl0
        ng: waveguide group index (used for linear neff dispersion)
        length: waveguide length, in the same unit as wl (the propagation
            phase is 2*pi*neff*length/wl)
        loss: propagation loss in dB per unit of length (same unit as length)

    Returns:
        A reciprocal SDict between ports "in0" and "out0" for modes "TE"/"TM".
    """
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0
    neff_wl = neff - dwl * dneff_dwl  # effective index at wl (linear dispersion)
    phase = 2 * jnp.pi * neff_wl * length / wl
    # loss is given in dB (power), so the field amplitude is scaled with a /20 exponent
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    sdict = sax.reciprocal(
        {
            ("in0@TE", "out0@TE"): 0.95 * transmission,  # 5% lost to cross-polarization
            ("in0@TE", "out0@TM"): 0.05 * transmission,  # 5% cross-polarization
            ("in0@TM", "out0@TM"): 0.85 * transmission,  # 10% worse tm->tm than te->te
            ("in0@TM", "out0@TE"): 0.05 * transmission,  # 5% cross-polarization
        }
    )
    return sdict


waveguide()
{('in0@TE',
  'out0@TE'): Array(0.77972527+0.5427048j, dtype=complex128, weak_type=True),
 ('in0@TE',
  'out0@TM'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True),
 ('in0@TM',
  'out0@TM'): Array(0.69764893+0.48557798j, dtype=complex128, weak_type=True),
 ('in0@TM',
  'out0@TE'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True),
 ('out0@TE',
  'in0@TE'): Array(0.77972527+0.5427048j, dtype=complex128, weak_type=True),
 ('out0@TM',
  'in0@TE'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True),
 ('out0@TM',
  'in0@TM'): Array(0.69764893+0.48557798j, dtype=complex128, weak_type=True),
 ('out0@TE',
  'in0@TM'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True)}

Multimode Coupler

def coupler():
    """Dummy two-mode (TE/TM) 2x2 coupler with cross-polarization.

    Same-mode paths carry 45% of the power (amplitude sqrt(0.45)). Each
    cross-polarization path carries 1% (amplitude sqrt(0.01)). Bar paths
    (in0->out0, in1->out1) are real; cross paths (in0->out1, in1->out0) carry
    an extra factor 1j.

    Only the forward in->out entries are returned.
    NOTE(review): unlike `waveguide` and `coupler_singlemode`, this SDict is
    not wrapped in sax.reciprocal, so it has no out->in entries. The multimode
    MZI output shows every out->in entry as zero - confirm this is intended.
    """
    # (input mode, output mode, power fraction), in the order entries are emitted
    channels = [
        ("TE", "TE", 0.45),
        ("TM", "TM", 0.45),
        ("TE", "TM", 0.01),
        ("TM", "TE", 0.01),
    ]
    port_pairs = [("in0", "out0"), ("in0", "out1"), ("in1", "out0"), ("in1", "out1")]
    sdict = {}
    for mode_in, mode_out, power in channels:
        amplitude = power**0.5
        for port_in, port_out in port_pairs:
            # same index -> bar path (real); different index -> cross path (x 1j)
            factor = 1 if port_in[-1] == port_out[-1] else 1j
            sdict[(f"{port_in}@{mode_in}", f"{port_out}@{mode_out}")] = factor * amplitude
    return sdict

Multimode MZI

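We can now combine these models into a circuit in much the same way as before. No extra keyword is needed here: the netlist refers to plain port names (e.g. "lft,out0"), and the "@TE"/"@TM" modes come from the models themselves, as the output below shows: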

# Both models used here are multimode: coupler() and waveguide() define
# "@TE"/"@TM" ports themselves, while the netlist only uses plain port names.
mzi, _ = sax.circuit(
    netlist={
        "instances": {
            "lft": "coupler",  # multimode: coupler() already defines "@TE"/"@TM" ports
            "top": {"component": "straight", "settings": {"length": 25.0}},
            "btm": {"component": "straight", "settings": {"length": 15.0}},
            "rgt": "coupler",  # multimode: coupler() already defines "@TE"/"@TM" ports
        },
        "connections": {
            "lft,out0": "btm,in0",
            "btm,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    },
    models={
        "coupler": coupler,
        "straight": waveguide,
    },
)
mzi()
{('in0@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'out0@TE'): Array(-0.24856154+0.09205708j, dtype=complex128),
 ('in0@TE', 'out0@TM'): Array(-0.08070811+0.029891j, dtype=complex128),
 ('in0@TE', 'out1@TE'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in0@TE', 'out1@TM'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in0@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'out0@TE'): Array(-0.08070811+0.029891j, dtype=complex128),
 ('in0@TM', 'out0@TM'): Array(-0.22385744+0.08290769j, dtype=complex128),
 ('in0@TM', 'out1@TE'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in0@TM', 'out1@TM'): Array(0.71348524-0.26424592j, dtype=complex128),
 ('in1@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'out0@TE'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in1@TE', 'out0@TM'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in1@TE', 'out1@TE'): Array(0.24856154-0.09205708j, dtype=complex128),
 ('in1@TE', 'out1@TM'): Array(0.08070811-0.029891j, dtype=complex128),
 ('in1@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'out0@TE'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in1@TM', 'out0@TM'): Array(0.71348524-0.26424592j, dtype=complex128),
 ('in1@TM', 'out1@TE'): Array(0.08070811-0.029891j, dtype=complex128),
 ('in1@TM', 'out1@TM'): Array(0.22385744-0.08290769j, dtype=complex128),
 ('out0@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out1@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out1@TM'): Array(0.+0.j, dtype=complex128)}

We can convert this multimode model back to a singlemode model, keeping only the TE entries, as follows:

# TE-only view of the multimode MZI: the output below keeps the TE entries and
# drops the "@TE" suffix from the port names.
mzi_te = sax.singlemode(mzi, mode="TE")
mzi_te()
{('in0', 'in0'): Array(0.+0.j, dtype=complex128),
 ('in0', 'in1'): Array(0.+0.j, dtype=complex128),
 ('in0', 'out0'): Array(-0.24856154+0.09205708j, dtype=complex128),
 ('in0', 'out1'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in1', 'in0'): Array(0.+0.j, dtype=complex128),
 ('in1', 'in1'): Array(0.+0.j, dtype=complex128),
 ('in1', 'out0'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in1', 'out1'): Array(0.24856154-0.09205708j, dtype=complex128),
 ('out0', 'in0'): Array(0.+0.j, dtype=complex128),
 ('out0', 'in1'): Array(0.+0.j, dtype=complex128),
 ('out0', 'out0'): Array(0.+0.j, dtype=complex128),
 ('out0', 'out1'): Array(0.+0.j, dtype=complex128),
 ('out1', 'in0'): Array(0.+0.j, dtype=complex128),
 ('out1', 'in1'): Array(0.+0.j, dtype=complex128),
 ('out1', 'out0'): Array(0.+0.j, dtype=complex128),
 ('out1', 'out1'): Array(0.+0.j, dtype=complex128)}

Model Conversion: Singlemode ↔ Multimode

SAX provides utilities to convert between singlemode and multimode representations. This is useful when you want to:

- use a singlemode model in a multimode circuit,
- extract a specific mode from a multimode result,
- compare singlemode and multimode simulations.

Converting Singlemode Models to Multimode

When you use a singlemode model in a multimode circuit, SAX automatically converts it. However, you can also do this explicitly using sax.multimode():

# Define a simple singlemode waveguide
def simple_waveguide_singlemode(wl=1.55, length=10.0, neff=2.4):
    """Lossless, dispersionless singlemode straight waveguide.

    Returns a reciprocal SDict with a unit-magnitude in0<->out0 transmission
    whose phase is 2*pi*neff*length/wl.
    """
    phi = 2 * jnp.pi * neff * length / wl
    forward = {("in0", "out0"): jnp.exp(1j * phi)}
    return sax.reciprocal(forward)


# Evaluate the plain singlemode model first and show its SDict.
sm_result = simple_waveguide_singlemode()
print("Singlemode result:", sm_result, sep="\n")

# Wrap it as a multimode model: each port gets "@TE" and "@TM" variants, and
# no cross-polarization entries are created.
wg_multimode = sax.multimode(simple_waveguide_singlemode, modes=("TE", "TM"))
mm_result = wg_multimode()
print("\nMultimode result (no cross-polarization):", mm_result, sep="\n")
Singlemode result:
{('in0', 'out0'): Array(-0.99486932+0.10116832j, dtype=complex128, weak_type=True), ('out0', 'in0'): Array(-0.99486932+0.10116832j, dtype=complex128, weak_type=True)}

Multimode result (no cross-polarization):
{('in0@TE', 'out0@TE'): Array(-0.99486932+0.10116832j, dtype=complex128), ('in0@TM', 'out0@TM'): Array(-0.99486932+0.10116832j, dtype=complex128), ('out0@TE', 'in0@TE'): Array(-0.99486932+0.10116832j, dtype=complex128), ('out0@TM', 'in0@TM'): Array(-0.99486932+0.10116832j, dtype=complex128)}

Extracting Singlemode from Multimode

You can extract a specific mode from a multimode model or result using sax.singlemode():

# Extract TE mode from the multimode MZI we created earlier
# (the "@TE" suffix is dropped from every port name in the result)
mzi_te_only = sax.singlemode(mzi, mode="TE")
print("MZI TE-only result:")
print(mzi_te_only())

# Extract TM mode (same port names, TM entries of the multimode MZI)
mzi_tm_only = sax.singlemode(mzi, mode="TM")
print("\nMZI TM-only result:")
print(mzi_tm_only())
MZI TE-only result:
{('in0', 'in0'): Array(0.+0.j, dtype=complex128), ('in0', 'in1'): Array(0.+0.j, dtype=complex128), ('in0', 'out0'): Array(-0.24856154+0.09205708j, dtype=complex128), ('in0', 'out1'): Array(0.7922229-0.29340714j, dtype=complex128), ('in1', 'in0'): Array(0.+0.j, dtype=complex128), ('in1', 'in1'): Array(0.+0.j, dtype=complex128), ('in1', 'out0'): Array(0.7922229-0.29340714j, dtype=complex128), ('in1', 'out1'): Array(0.24856154-0.09205708j, dtype=complex128), ('out0', 'in0'): Array(0.+0.j, dtype=complex128), ('out0', 'in1'): Array(0.+0.j, dtype=complex128), ('out0', 'out0'): Array(0.+0.j, dtype=complex128), ('out0', 'out1'): Array(0.+0.j, dtype=complex128), ('out1', 'in0'): Array(0.+0.j, dtype=complex128), ('out1', 'in1'): Array(0.+0.j, dtype=complex128), ('out1', 'out0'): Array(0.+0.j, dtype=complex128), ('out1', 'out1'): Array(0.+0.j, dtype=complex128)}

MZI TM-only result:


{('in0', 'in0'): Array(0.+0.j, dtype=complex128), ('in0', 'in1'): Array(0.+0.j, dtype=complex128), ('in0', 'out0'): Array(-0.22385744+0.08290769j, dtype=complex128), ('in0', 'out1'): Array(0.71348524-0.26424592j, dtype=complex128), ('in1', 'in0'): Array(0.+0.j, dtype=complex128), ('in1', 'in1'): Array(0.+0.j, dtype=complex128), ('in1', 'out0'): Array(0.71348524-0.26424592j, dtype=complex128), ('in1', 'out1'): Array(0.22385744-0.08290769j, dtype=complex128), ('out0', 'in0'): Array(0.+0.j, dtype=complex128), ('out0', 'in1'): Array(0.+0.j, dtype=complex128), ('out0', 'out0'): Array(0.+0.j, dtype=complex128), ('out0', 'out1'): Array(0.+0.j, dtype=complex128), ('out1', 'in0'): Array(0.+0.j, dtype=complex128), ('out1', 'in1'): Array(0.+0.j, dtype=complex128), ('out1', 'out0'): Array(0.+0.j, dtype=complex128), ('out1', 'out1'): Array(0.+0.j, dtype=complex128)}

Using Singlemode Models in Multimode Circuits

# A singlemode coupler (no cross-polarization)
def coupler_singlemode(coupling=0.5):
    """Ideal 2x2 directional coupler for a single mode.

    Args:
        coupling: fraction of power transferred to the cross port.

    Returns:
        A reciprocal SDict. The bar paths (in0->out0, in1->out1) carry
        sqrt(1 - coupling); the cross paths (in0->out1, in1->out0) carry
        1j * sqrt(coupling). Power is conserved for 0 <= coupling <= 1.
    """
    through = (1 - coupling) ** 0.5
    cross = 1j * coupling**0.5
    forward = {
        ("in0", "out0"): through,
        ("in0", "out1"): cross,
        ("in1", "out0"): cross,
        ("in1", "out1"): through,
    }
    return sax.reciprocal(forward)


# Build an MZI from the singlemode coupler and the waveguide model.
# NOTE(review): the waveguide is wrapped in sax.singlemode(..., mode="TE"), so it
# enters the circuit as the TE-only (singlemode) view of the multimode
# `waveguide`, not as a multimode model. Every instance here is therefore
# singlemode. This matches the 16 port pairs printed below (4 x 4 ports) rather
# than the 64 of the two-mode MZI above. To actually mix singlemode and
# multimode models, pass `waveguide` directly - confirm sax.circuit supports that.
mixed_mzi, _ = sax.circuit(
    netlist={
        "instances": {
            "lft": "coupler_sm",  # singlemode coupler
            "top": {
                "component": "straight_mm",
                "settings": {"length": 25.0},
            },  # TE-only view of the multimode waveguide (see note above)
            "btm": {
                "component": "straight_mm",
                "settings": {"length": 15.0},
            },  # TE-only view of the multimode waveguide (see note above)
            "rgt": "coupler_sm",  # singlemode coupler
        },
        "connections": {
            "lft,out0": "btm,in0",
            "btm,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    },
    models={
        "coupler_sm": coupler_singlemode,  # singlemode model
        "straight_mm": sax.singlemode(
            waveguide, mode="TE"
        ),  # singlemode (TE) view of the multimode `waveguide` defined earlier
    },
)

print("Mixed circuit result (singlemode couplers + multimode waveguides):")
result = mixed_mzi()
print(f"Number of port pairs: {len(result)}")
Mixed circuit result (singlemode couplers + multimode waveguides):


Number of port pairs: 16

Summary: Conversion Functions

| Function | Purpose | Example |
|---|---|---|
| `sax.multimode(model, modes)` | Convert a singlemode model → multimode | `sax.multimode(simple_waveguide_singlemode, modes=("TE", "TM"))` |
| `sax.singlemode(model, mode)` | Extract a single mode from a multimode model | `sax.singlemode(mzi, mode="TE")` |

These functions work on both model functions and S-parameter dictionaries (SDict), making it easy to switch between representations at any point in your workflow.