Simulating an All-Pass Filter

A simple comparison between an analytical evaluation of an all-pass filter and a circuit simulation of the same filter using SAX.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax

Schematic


           in0---out0
        in1          out1
           \        /
            ========
           /        \
 in0 <- in0          out0 -> out0

Simulation & Design Parameters

loss = 0.1  # [dB/μm] (alpha) waveguide loss
neff = 2.34  # Effective index of the waveguides
ng = 3.4  # Group index of the waveguides
wl0 = 1.55  # [μm] the wavelength at which neff and ng are defined
ring_length = 10.0  # [μm] Length of the ring
coupling = 0.5  # [-] power coupling ratio of the coupler
wl = jnp.linspace(1.5, 1.6, 1000)  # [μm] Wavelengths to sweep over

Analytical Frequency-Domain Response

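As a reference, we first calculate the frequency-domain response of the all-pass filter analytically. Here \(s\) is the input field, \(o\) the output field, \(t\) the field self-coupling (through) coefficient of the coupler, \(\alpha\) the waveguide loss in dB/μm and \(L\) the ring length: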

\[\begin{align*} o = \frac{t-10^{-\alpha L/20}\exp(2\pi j n_{\rm eff}(\lambda) L / \lambda)}{1-t10^{-\alpha L/20}\exp(2\pi j n_{\rm eff}(\lambda) L / \lambda)}s \end{align*}\]
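The quantities entering the formula can be worked out from the parameters above. Writing \(\kappa\) for `coupling` (read, as in the code, as the power coupling ratio) and assuming a first-order expansion of the effective index around \(\lambda_0\) = `wl0`, with the group index defined as \(n_g = n_{\rm eff} - \lambda\, dn_{\rm eff}/d\lambda\):

\[\begin{align*}
t &= \sqrt{1-\kappa} = \sqrt{0.5} \approx 0.707 \\
10^{-\alpha L/20} &= 10^{-0.1 \cdot 10/20} = 10^{-0.05} \approx 0.891 \\
\left.\frac{dn_{\rm eff}}{d\lambda}\right|_{\lambda_0} &= \frac{n_{\rm eff} - n_g}{\lambda_0} \\
n_{\rm eff}(\lambda) &\approx n_{\rm eff} + (\lambda - \lambda_0)\frac{n_{\rm eff} - n_g}{\lambda_0} = n_{\rm eff} + (\lambda_0 - \lambda)\frac{n_g - n_{\rm eff}}{\lambda_0}
\end{align*}\]

The last line is the `neff_wl` expression used in the code.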
def all_pass_analytical():
    """Analytic frequency-domain response of an all-pass filter."""
    t = jnp.sqrt(1 - coupling)  # field self-coupling of the coupler
    # first-order (linear) dispersion of neff around wl0, set by the group index
    neff_wl = neff + (wl0 - wl) * (ng - neff) / wl0
    # single round-trip field factor: loss amplitude times propagation phase
    round_trip = 10 ** (-loss * ring_length / 20.0) * jnp.exp(
        2j * jnp.pi * neff_wl * ring_length / wl
    )
    out = (t - round_trip) / (1 - t * round_trip)  # o / s
    detected = jnp.abs(out) ** 2  # power transmission
    return detected
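The name "all-pass" comes from the lossless case: with \(\alpha = 0\), the numerator and denominator of the formula have equal magnitude, so \(|o/s| = 1\) at every wavelength and the ring only shapes the phase. A quick numerical sketch of that, re-implementing the same formula with `wl` and `loss` as arguments (`all_pass_field` is a hypothetical helper, not part of SAX):

```python
import jax.numpy as jnp

def all_pass_field(wl, loss, coupling=0.5, ring_length=10.0, neff=2.34, ng=3.4, wl0=1.55):
    """Complex field ratio o/s of the all-pass ring (same model as above)."""
    t = jnp.sqrt(1 - coupling)                       # field self-coupling
    a = 10 ** (-loss * ring_length / 20.0)           # round-trip field amplitude
    neff_wl = neff + (wl0 - wl) * (ng - neff) / wl0  # linear dispersion around wl0
    phase = jnp.exp(2j * jnp.pi * neff_wl * ring_length / wl)
    return (t - a * phase) / (1 - t * a * phase)

wl = jnp.linspace(1.5, 1.6, 1000)
T_lossless = jnp.abs(all_pass_field(wl, loss=0.0)) ** 2
T_lossy = jnp.abs(all_pass_field(wl, loss=0.1)) ** 2

print(float(jnp.max(jnp.abs(T_lossless - 1))))    # ~0, up to float32 rounding
print(float(T_lossy.min()), float(T_lossy.max()))  # loss opens a dip below 1
```

Only the loss term breaks the unit magnitude, which is why the curve plotted next dips below 1.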
%time detected = all_pass_analytical() # non-jitted evaluation time
all_pass_analytical_jitted = jax.jit(all_pass_analytical)
%time detected = all_pass_analytical_jitted() # time to jit
%time detected = all_pass_analytical_jitted() # evaluation time after jitting

plt.plot(wl * 1e3, detected)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.show()
CPU times: user 236 ms, sys: 20.4 ms, total: 257 ms
Wall time: 254 ms
CPU times: user 20.8 ms, sys: 135 μs, total: 20.9 ms
Wall time: 19.2 ms
CPU times: user 23 μs, sys: 2 μs, total: 25 μs
Wall time: 27.7 μs
[Figure: analytical transmission T versus λ [nm]]
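The dip in the transmission sits at a ring resonance, where the round-trip phase \(2\pi n_{\rm eff}(\lambda) L/\lambda\) is a multiple of \(2\pi\). As a sketch under the same linear \(n_{\rm eff}(\lambda)\) model as the code above, the condition \(n_{\rm eff}(\lambda) L = m\lambda\) has the closed-form solution \(\lambda_m = n_g L / (m + (n_g - n_{\rm eff}) L/\lambda_0)\), and the resonance spacing (free spectral range) is approximately \(\lambda^2/(n_g L)\). The names below are only for illustration:

```python
import jax.numpy as jnp

neff, ng, wl0, ring_length = 2.34, 3.4, 1.55, 10.0  # same values as above

def neff_wl(wl):
    """Linear dispersion model of the effective index around wl0."""
    return neff + (wl0 - wl) * (ng - neff) / wl0

# Resonance condition neff(wl) * L = m * wl, solved for wl with the linear model.
m = jnp.arange(10, 20)
wl_res = ng * ring_length / (m + (ng - neff) * ring_length / wl0)
in_band = wl_res[(wl_res >= 1.5) & (wl_res <= 1.6)]  # resonances inside the sweep

fsr_estimate = wl0**2 / (ng * ring_length)  # approximate FSR, lambda^2 / (n_g L)

print(in_band * 1e3)       # one resonance, near 1556.9 nm
print(fsr_estimate * 1e3)  # about 70.7 nm
```

With an FSR of roughly 70 nm and a 100 nm sweep centred near 1550 nm, only one resonance falls inside the plotted range.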

Scatter Dictionaries

all_pass_sax, _ = sax.circuit(
    netlist={
        "instances": {
            "dc": {"component": "coupler", "settings": {"coupling": coupling}},
            "top": {
                "component": "straight",
                "settings": {
                    "length": ring_length,
                    "loss": loss,
                    "neff": neff,
                    "ng": ng,
                    "wl0": wl0,
                    "wl": wl,
                },
            },
        },
        "connections": {
            "dc,out1": "top,in0",
            "top,out0": "dc,in1",
        },
        "ports": {
            "in0": "dc,in0",
            "out0": "dc,out0",
        },
    },
    models={
        "coupler": sax.models.coupler,
        "straight": sax.models.straight,
    },
)
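To illustrate what evaluating this netlist amounts to, here is a minimal sketch of one textbook approach: stack the component S-matrices, then eliminate the internally connected ports with a connection matrix. This is not claimed to be how SAX works internally. `coupler_s`, `straight_s` and `all_pass_connected` are hypothetical helpers, and the coupler convention (\(\sqrt{1-\kappa}\) on the through paths, \(j\sqrt{\kappa}\) on the cross paths) is an assumption. With that convention, the result reproduces the analytical expression from the previous section.

```python
import math
import jax
import jax.numpy as jnp

def coupler_s(coupling):
    """Hypothetical 4x4 coupler S-matrix, port order (in0, in1, out0, out1).
    Assumed convention: tau on the through paths, 1j * kappa on the cross paths."""
    tau, kappa = math.sqrt(1 - coupling), math.sqrt(coupling)
    return jnp.array([[0, 0, tau, 1j * kappa],
                      [0, 0, 1j * kappa, tau],
                      [tau, 1j * kappa, 0, 0],
                      [1j * kappa, tau, 0, 0]], dtype=jnp.complex64)

def straight_s(wl, length=10.0, loss=0.1, neff=2.34, ng=3.4, wl0=1.55):
    """Hypothetical 2x2 waveguide S-matrix, port order (in0, out0)."""
    neff_wl = neff + (wl0 - wl) * (ng - neff) / wl0
    A = 10 ** (-loss * length / 20.0) * jnp.exp(2j * jnp.pi * neff_wl * length / wl)
    z = jnp.zeros_like(A)
    return jnp.array([[z, A], [A, z]])

# Port order of the stacked components:
# 0 dc.in0, 1 dc.in1, 2 dc.out0, 3 dc.out1, 4 top.in0, 5 top.out0
EXT = jnp.array([0, 2])        # exposed as "in0" and "out0"
INT = jnp.array([1, 3, 4, 5])  # dc.in1, dc.out1, top.in0, top.out0
# Connection matrix on INT: dc.out1 <-> top.in0 and top.out0 <-> dc.in1
C = jnp.array([[0, 0, 0, 1],
               [0, 0, 1, 0],
               [0, 1, 0, 0],
               [1, 0, 0, 0]], dtype=jnp.complex64)

def all_pass_connected(wl, coupling=0.5):
    S = jnp.zeros((6, 6), dtype=jnp.complex64)
    S = S.at[:4, :4].set(coupler_s(coupling))
    S = S.at[4:, 4:].set(straight_s(wl))
    S_ee, S_ei = S[EXT][:, EXT], S[EXT][:, INT]
    S_ie, S_ii = S[INT][:, EXT], S[INT][:, INT]
    # What leaves one internal port enters its partner: a_int = C @ b_int, hence
    # b_ext = (S_ee + S_ei C (I - S_ii C)^-1 S_ie) a_ext.
    I = jnp.eye(4, dtype=jnp.complex64)
    S_ext = S_ee + S_ei @ C @ jnp.linalg.solve(I - S_ii @ C, S_ie)
    return S_ext[1, 0]  # field transmission from "in0" to "out0"

wl = jnp.linspace(1.5, 1.6, 1000)
field = jax.vmap(all_pass_connected)(wl)

# Closed-form result from the analytical section, for comparison.
t = math.sqrt(1 - 0.5)
neff_wl = 2.34 + (1.55 - wl) * (3.4 - 2.34) / 1.55
A = 10 ** (-0.1 * 10.0 / 20.0) * jnp.exp(2j * jnp.pi * neff_wl * 10.0 / wl)
closed_form = (t - A) / (1 - t * A)
print(float(jnp.max(jnp.abs(field - closed_form))))  # tiny: float32 rounding only
```

Solving the linear system with `jnp.linalg.solve` rather than forming the inverse explicitly is the usual numerically preferable choice.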
%time S = all_pass_sax() # non-jitted evaluation time
all_pass_sax_jitted = jax.jit(all_pass_sax)
%time S = all_pass_sax_jitted() # time to jit
%time S = all_pass_sax_jitted() # evaluation time after jitting
detected_sax = abs(S["in0", "out0"]) ** 2  # power transmission in0 -> out0

plt.plot(wl * 1e3, detected, label="analytical")
plt.plot(wl * 1e3, detected_sax, label="sax", ls="--", lw=3)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.legend()
plt.show()
CPU times: user 350 ms, sys: 7.79 ms, total: 358 ms
Wall time: 316 ms
CPU times: user 202 μs, sys: 13 μs, total: 215 μs
Wall time: 194 μs
CPU times: user 24 μs, sys: 0 ns, total: 24 μs
Wall time: 26.5 μs
[Figure: transmission T versus λ [nm], analytical result (solid) and SAX result (dashed)]
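The `%time` figures above measure how long the Python call takes to return. JAX dispatches work asynchronously, so for more faithful timings the usual advice is to wait for the result with `block_until_ready()`. The first call of a jitted function also includes tracing and compilation, while later calls with the same input shapes and dtypes reuse the cached executable. A minimal sketch using the standard-library timer, where `transmission` is a hypothetical stand-in for the response functions above and `wl` is passed as an argument instead of being read from a global:

```python
import time
import jax
import jax.numpy as jnp

def transmission(wl, neff=2.34, ring_length=10.0):
    """Hypothetical stand-in for the response functions above."""
    return jnp.cos(2 * jnp.pi * neff * ring_length / wl) ** 2

transmission_jitted = jax.jit(transmission)

def wall_time(fn, *args):
    """Wall-clock time of one call, including waiting for the asynchronous result."""
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - start

wl = jnp.linspace(1.5, 1.6, 1000)
t_first = wall_time(transmission_jitted, wl)   # trace + compile + run
t_cached = wall_time(transmission_jitted, wl)  # reuses the compiled executable
print(f"first call: {t_first * 1e3:.2f} ms, cached call: {t_cached * 1e6:.1f} μs")
```

Taking `wl` as an argument means a new sweep with the same length and dtype reuses the compiled code, whereas a function that reads `wl` from a global captures that array as a constant at trace time.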