Multimode simulations

SAX can handle multiple modes too!

from itertools import combinations_with_replacement, product

import jax.numpy as jnp
import sax

Ports and modes per port

Let’s denote a combination of a port and a mode by a string of the format "{port}@{mode}". All unordered pairs of port-mode combinations (including a combination paired with itself) can be generated with the itertools functions product and combinations_with_replacement:

ports = ["in0", "out0"]
modes = ["TE", "TM"]
# First label every (port, mode) combination as "{port}@{mode}", then take all
# unordered pairs of those labels (a label may be paired with itself).
portmodes = list(
    combinations_with_replacement(
        (f"{port}@{mode}" for port, mode in product(ports, modes)),
        2,
    )
)
portmodes
[('in0@TE', 'in0@TE'),
 ('in0@TE', 'in0@TM'),
 ('in0@TE', 'out0@TE'),
 ('in0@TE', 'out0@TM'),
 ('in0@TM', 'in0@TM'),
 ('in0@TM', 'out0@TE'),
 ('in0@TM', 'out0@TM'),
 ('out0@TE', 'out0@TE'),
 ('out0@TE', 'out0@TM'),
 ('out0@TM', 'out0@TM')]

If we disregard backreflection (pairs whose two ends sit on the same port), the list can be simplified further:

# A back-reflection term starts and ends on the same port; keep only the
# pairs whose port names (the part before "@") differ.
portmodes_without_backreflection = [
    (src, dst)
    for src, dst in portmodes
    if src.split("@")[0] != dst.split("@")[0]
]
portmodes_without_backreflection
[('in0@TE', 'out0@TE'),
 ('in0@TE', 'out0@TM'),
 ('in0@TM', 'out0@TE'),
 ('in0@TM', 'out0@TM')]

Sometimes cross-polarization terms (pairs that couple two different modes) can be ignored as well:

# A cross-polarization term couples two different modes; keep only the pairs
# whose mode names (the part after "@") are equal.
portmodes_without_crosspolarization = [
    (src, dst)
    for src, dst in portmodes
    if src.split("@")[1] == dst.split("@")[1]
]
portmodes_without_crosspolarization
[('in0@TE', 'in0@TE'),
 ('in0@TE', 'out0@TE'),
 ('in0@TM', 'in0@TM'),
 ('in0@TM', 'out0@TM'),
 ('out0@TE', 'out0@TE'),
 ('out0@TM', 'out0@TM')]

Multimode waveguide

Let’s create a waveguide with two ports ("in0", "out0") and two modes ("TE", "TM") without backreflection. Let’s assume there is 5% cross-polarization and that the "TM"->"TM" transmission is 10% worse than the "TE"->"TE" transmission. These percentages are applied to the field amplitude, not to the power. Naturally, in more realistic waveguide models these percentages will be length-dependent, but this is just a dummy model serving as an example.

def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
    """a simple straight multimode (TE/TM) waveguide model without backreflection

    Args:
        wl: wavelength
        wl0: center wavelength at which neff is defined (same unit as wl)
        neff: waveguide effective index at wl0
        ng: waveguide group index (used for linear neff dispersion)
        length: waveguide length; must be in the same unit as wl for the
            phase to be correct (presumably µm, given the 1.55 default)
        loss: waveguide loss in dB per unit of length

    Returns:
        a reciprocal SDict keyed by ("{port}@{mode}", "{port}@{mode}") pairs.
        Every entry is a fixed fraction of the same complex transmission:
        0.95 for TE->TE, 0.85 for TM->TM and 0.05 for each cross-polarization
        term. These fractions scale the field amplitude, not the power.
    """
    dwl = wl - wl0
    # linear dispersion from the group index: ng = neff - wl0 * dneff/dwl,
    # hence dneff/dwl = (neff - ng) / wl0
    dneff_dwl = (neff - ng) / wl0
    neff = neff + dwl * dneff_dwl
    phase = 2 * jnp.pi * neff * length / wl
    # dB loss -> field-amplitude factor (hence /20 rather than /10)
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    # sax.reciprocal adds the mirrored (out -> in) entries for every key below
    sdict = sax.reciprocal(
        {
            ("in0@TE", "out0@TE"): 0.95 * transmission,  # 5% lost to cross-polarization
            ("in0@TE", "out0@TM"): 0.05 * transmission,  # 5% cross-polarization
            ("in0@TM", "out0@TM"): 0.85 * transmission,  # 10% worse tm->tm than te->te
            ("in0@TM", "out0@TE"): 0.05 * transmission,  # 5% cross-polarization
        }
    )
    return sdict


waveguide()
{('in0@TE',
  'out0@TE'): Array(0.77972527+0.5427048j, dtype=complex128, weak_type=True),
 ('in0@TE',
  'out0@TM'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True),
 ('in0@TM',
  'out0@TM'): Array(0.69764893+0.48557798j, dtype=complex128, weak_type=True),
 ('in0@TM',
  'out0@TE'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True),
 ('out0@TE',
  'in0@TE'): Array(0.77972527+0.5427048j, dtype=complex128, weak_type=True),
 ('out0@TM',
  'in0@TE'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True),
 ('out0@TM',
  'in0@TM'): Array(0.69764893+0.48557798j, dtype=complex128, weak_type=True),
 ('out0@TE',
  'in0@TM'): Array(0.04103817+0.02856341j, dtype=complex128, weak_type=True)}

Multimode Coupler

def coupler():
    """Dummy 2x2 multimode (TE/TM) directional coupler.

    For every (input mode, output mode) combination, the bar paths
    (in0->out0, in1->out1) get a real amplitude and the crossing paths
    (in0->out1, in1->out0) get the same amplitude multiplied by 1j.
    Same-mode terms use amplitude sqrt(0.45); cross-polarization terms use
    sqrt(0.01). Only forward (in -> out) entries are returned; the dict is
    not made reciprocal.
    """
    same_mode = 0.45**0.5
    cross_mode = 0.01**0.5
    mode_terms = [
        ("TE", "TE", same_mode),
        ("TM", "TM", same_mode),
        ("TE", "TM", cross_mode),
        ("TM", "TE", cross_mode),
    ]
    sdict = {}
    for mode_in, mode_out, amp in mode_terms:
        sdict[(f"in0@{mode_in}", f"out0@{mode_out}")] = amp
        sdict[(f"in0@{mode_in}", f"out1@{mode_out}")] = 1j * amp
        sdict[(f"in1@{mode_in}", f"out0@{mode_out}")] = 1j * amp
        sdict[(f"in1@{mode_in}", f"out1@{mode_out}")] = amp
    return sdict

Multimode MZI

We can now combine these models into a circuit in much the same way as before. No extra keyword is needed: both models already return "{port}@{mode}" keys, and the resulting circuit is multimode as well:

# Two-arm MZI built from the multimode `coupler` and `straight` (waveguide)
# models. Both models already return "{port}@{mode}" keys, so the circuit is
# multimode without any extra arguments. The second value returned by
# sax.circuit is not needed here.
mzi, _ = sax.circuit(
    netlist={
        "instances": {
            # `coupler` is already multimode (it defines TE/TM and cross-polarization
            # terms itself); per the SAX docs, a single-mode model placed here would
            # instead be converted automatically, without cross-polarization.
            "lft": "coupler",
            "top": {"component": "straight", "settings": {"length": 25.0}},  # longer arm
            "btm": {"component": "straight", "settings": {"length": 15.0}},  # shorter arm
            "rgt": "coupler",  # same multimode coupler as "lft"
        },
        "connections": {
            "lft,out0": "btm,in0",
            "btm,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    },
    models={
        "coupler": coupler,
        "straight": waveguide,
    },
)
mzi()
{('in0@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TE', 'out0@TE'): Array(-0.24856154+0.09205708j, dtype=complex128),
 ('in0@TE', 'out0@TM'): Array(-0.08070811+0.029891j, dtype=complex128),
 ('in0@TE', 'out1@TE'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in0@TE', 'out1@TM'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in0@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in0@TM', 'out0@TE'): Array(-0.08070811+0.029891j, dtype=complex128),
 ('in0@TM', 'out0@TM'): Array(-0.22385744+0.08290769j, dtype=complex128),
 ('in0@TM', 'out1@TE'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in0@TM', 'out1@TM'): Array(0.71348524-0.26424592j, dtype=complex128),
 ('in1@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TE', 'out0@TE'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in1@TE', 'out0@TM'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in1@TE', 'out1@TE'): Array(0.24856154-0.09205708j, dtype=complex128),
 ('in1@TE', 'out1@TM'): Array(0.08070811-0.029891j, dtype=complex128),
 ('in1@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('in1@TM', 'out0@TE'): Array(0.25723534-0.09526951j, dtype=complex128),
 ('in1@TM', 'out0@TM'): Array(0.71348524-0.26424592j, dtype=complex128),
 ('in1@TM', 'out1@TE'): Array(0.08070811-0.029891j, dtype=complex128),
 ('in1@TM', 'out1@TM'): Array(0.22385744-0.08290769j, dtype=complex128),
 ('out0@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TE', 'out1@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out0@TM', 'out1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TE', 'out1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'in1@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out0@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out0@TM'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out1@TE'): Array(0.+0.j, dtype=complex128),
 ('out1@TM', 'out1@TM'): Array(0.+0.j, dtype=complex128)}

We can convert this model back to a single-mode SDict, keeping only the TE entries, as follows:

# Keep only the TE entries of the multimode MZI. In the resulting SDict the
# "@TE" suffix is dropped from the port names, so the keys are plain port pairs.
mzi_te = sax.singlemode(mzi, mode="TE")
mzi_te()
{('in0', 'in0'): Array(0.+0.j, dtype=complex128),
 ('in0', 'in1'): Array(0.+0.j, dtype=complex128),
 ('in0', 'out0'): Array(-0.24856154+0.09205708j, dtype=complex128),
 ('in0', 'out1'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in1', 'in0'): Array(0.+0.j, dtype=complex128),
 ('in1', 'in1'): Array(0.+0.j, dtype=complex128),
 ('in1', 'out0'): Array(0.7922229-0.29340714j, dtype=complex128),
 ('in1', 'out1'): Array(0.24856154-0.09205708j, dtype=complex128),
 ('out0', 'in0'): Array(0.+0.j, dtype=complex128),
 ('out0', 'in1'): Array(0.+0.j, dtype=complex128),
 ('out0', 'out0'): Array(0.+0.j, dtype=complex128),
 ('out0', 'out1'): Array(0.+0.j, dtype=complex128),
 ('out1', 'in0'): Array(0.+0.j, dtype=complex128),
 ('out1', 'in1'): Array(0.+0.j, dtype=complex128),
 ('out1', 'out0'): Array(0.+0.j, dtype=complex128),
 ('out1', 'out1'): Array(0.+0.j, dtype=complex128)}