Thin film optimization#

Let’s optimize a thin film…

This notebook was contributed by simbilod.

import jax
import jax.example_libraries.optimizers as opt
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax
import tqdm.notebook as tqdm
from tmm import coh_tmm

In this notebook, we apply SAX to thin-film optimization and show how it can be used for wavelength-dependent parameter optimization.

The language of transfer/scatter matrices is commonly used to calculate the optical properties of thin films, and many specialized methods exist for their optimization. SAX can nevertheless cut down on developer time: it removes the need to manually take gradients of complicated or frequently changing objective functions, and it generates efficient code from simple syntax.

Dielectric mirror Fabry-Pérot#

Consider a stack composed of only two materials, with refractive indices \(n_A\) and \(n_B\). Two types of transfer matrices characterize wave propagation in the system: interfaces, described by Fresnel’s equations, and propagation through a layer.

For the two-material stack, each interface has four scatter matrix coefficients. Through reciprocity, they can be constructed from two independent ones:

def fresnel_mirror_ij(ni=1.0, nj=1.0):
    """Model a (Fresnel) interface between two refractive indices

    Args:
        ni: refractive index of the initial medium
        nj: refractive index of the final medium
    """
    r_fresnel_ij = (ni - nj) / (ni + nj)  # i->j reflection
    t_fresnel_ij = 2 * ni / (ni + nj)  # i->j transmission
    r_fresnel_ji = -r_fresnel_ij  # j -> i reflection
    t_fresnel_ji = (1 - r_fresnel_ij**2) / t_fresnel_ij  # j -> i transmission
    sdict = {
        ("in", "in"): r_fresnel_ij,
        ("in", "out"): t_fresnel_ij,
        ("out", "in"): t_fresnel_ji,
        ("out", "out"): r_fresnel_ji,
    }
    return sdict


def propagation_i(ni=1.0, di=0.5, wl=0.532):
    """Model the phase shift acquired as a wave propagates through a layer

    Args:
        ni: refractive index of medium (at wavelength wl)
        di: [μm] thickness of layer
        wl: [μm] wavelength
    """
    prop_i = jnp.exp(1j * 2 * jnp.pi * ni * di / wl)
    sdict = {
        ("in", "out"): prop_i,
        ("out", "in"): prop_i,
    }
    return sdict
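
As a quick sanity check on the interface model, here is a minimal plain-Python sketch (no SAX or JAX; fresnel_amplitudes is a helper name introduced here for illustration only). It assumes real refractive indices and normal incidence, and verifies that the reverse transmission used in fresnel_mirror_ij reduces to the usual Fresnel form, and that power is conserved once the index ratio is accounted for.

```python
def fresnel_amplitudes(ni, nj):
    """Normal-incidence Fresnel amplitudes for an i -> j interface (real indices)."""
    r_ij = (ni - nj) / (ni + nj)
    t_ij = 2 * ni / (ni + nj)
    r_ji = -r_ij
    t_ji = (1 - r_ij**2) / t_ij  # same expression as in fresnel_mirror_ij
    return r_ij, t_ij, r_ji, t_ji


ni, nj = 1.0, 2.0  # illustrative values: air to a medium of index 2
r_ij, t_ij, r_ji, t_ji = fresnel_amplitudes(ni, nj)

# The reverse transmission equals the direct Fresnel expression 2*nj/(ni+nj).
t_ji_direct = 2 * nj / (ni + nj)

# Power conservation: the transmitted power carries a factor nj/ni.
R = r_ij**2
T = (nj / ni) * t_ij**2
```

With these values, R + T evaluates to 1 up to rounding, and t_ij * t_ji + r_ij**2 is also 1, which is the relation that the expression for t_fresnel_ji encodes.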

A resonant cavity can be formed when a high-index region is surrounded by low-index regions:

dielectric_fabry_perot, _ = sax.circuit(
    netlist={
        "instances": {
            "air_B": fresnel_mirror_ij,
            "B": propagation_i,
            "B_air": fresnel_mirror_ij,
        },
        "connections": {
            "air_B,out": "B,in",
            "B,out": "B_air,in",
        },
        "ports": {
            "in": "air_B,in",
            "out": "B_air,out",
        },
    },
    backend="fg",
)

settings = sax.get_settings(dielectric_fabry_perot)
settings
{'air_B': {'ni': 1.0, 'nj': 1.0},
 'B': {'ni': 1.0, 'di': 0.5, 'wl': 0.532},
 'B_air': {'ni': 1.0, 'nj': 1.0}}

Let’s choose \(n_A = 1\), \(n_B = 2\), \(d_B = 500\) nm (the default di = 0.5 μm of propagation_i), and compute over the visible spectrum:

settings = sax.copy_settings(settings)
settings["air_B"]["nj"] = 2.0
settings["B"]["ni"] = 2.0
settings["B_air"]["ni"] = 2.0

wls = jnp.linspace(0.380, 0.750, 200)
settings = sax.update_settings(settings, wl=wls)

Compute transmission and reflection, and compare with the results of the tmm package (https://github.com/sbyrnes321/tmm):

%%time
sdict = dielectric_fabry_perot(**settings)

transmitted = sdict["in", "out"]
reflected = sdict["in", "in"]
CPU times: user 318 ms, sys: 3.96 ms, total: 322 ms
Wall time: 319 ms
%%time

# tmm syntax (https://github.com/sbyrnes321/tmm)
d_list = [jnp.inf, 0.500, jnp.inf]
n_list = [1, 2, 1]
# initialize lists of y-values to plot
rnorm = []
tnorm = []
Tnorm = []
Rnorm = []
for l in wls:
    rnorm.append(coh_tmm("s", n_list, d_list, 0, l)["r"])
    tnorm.append(coh_tmm("s", n_list, d_list, 0, l)["t"])
    Tnorm.append(coh_tmm("s", n_list, d_list, 0, l)["T"])
    Rnorm.append(coh_tmm("s", n_list, d_list, 0, l)["R"])
CPU times: user 765 ms, sys: 7.72 ms, total: 773 ms
Wall time: 646 ms
plt.scatter(wls * 1e3, jnp.real(transmitted), label="t SAX")
plt.plot(wls * 1e3, jnp.real(jnp.array(tnorm)), "k", label="t tmm")
plt.scatter(wls * 1e3, jnp.real(reflected), label="r SAX")
plt.plot(wls * 1e3, jnp.real(jnp.array(rnorm)), "k--", label="r tmm")
plt.xlabel("λ [nm]")
plt.ylabel("Transmitted and reflected amplitude")
plt.legend(loc="upper right")
plt.title("Real part")
plt.show()

plt.scatter(wls * 1e3, jnp.imag(transmitted), label="t SAX")
plt.plot(wls * 1e3, jnp.imag(jnp.array(tnorm)), "k", label="t tmm")
plt.scatter(wls * 1e3, jnp.imag(reflected), label="r SAX")
plt.plot(wls * 1e3, jnp.imag(jnp.array(rnorm)), "k--", label="r tmm")
plt.xlabel("λ [nm]")
plt.ylabel("Transmitted and reflected amplitude")
plt.legend(loc="upper right")
plt.title("Imaginary part")
plt.show()
[Figures: real and imaginary parts of the transmitted and reflected amplitudes versus wavelength (nm), SAX compared with tmm]

In terms of powers, we get the following. Due to the reflections at the interfaces, resonant behaviour is observed, with evenly spaced maxima and minima in wavevector space:

plt.scatter(2 * jnp.pi / wls, jnp.abs(transmitted) ** 2, label="T SAX")
plt.plot(2 * jnp.pi / wls, Tnorm, "k", label="T tmm")
plt.scatter(2 * jnp.pi / wls, jnp.abs(reflected) ** 2, label="R SAX")
plt.plot(2 * jnp.pi / wls, Rnorm, "k--", label="R tmm")
plt.vlines(
    jnp.arange(3, 6) * jnp.pi / (2 * 0.5),
    ymin=0,
    ymax=1,
    color="k",
    linestyle="--",
    label=r"m$\pi$/nd",
)
plt.xlabel(r"k = 2$\pi$/λ [1/μm]")  # wls is in μm, so k is in 1/μm
plt.ylabel("Transmitted and reflected intensities")
plt.legend(loc="upper right")
plt.show()
[Figure: transmitted and reflected intensities versus k = 2π/λ, SAX compared with tmm, with dashed lines at mπ/nd]
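
The position of the dashed lines can be checked by hand. The sketch below is a plain-Python stand-in (slab_transmission is a hypothetical helper, not part of SAX or tmm) that evaluates the closed-form transmission of a single slab in air at normal incidence. It assumes the same values as above (n = 2, d = 0.5 μm) and evaluates the transmission at the wavelengths where k = mπ/(nd).

```python
import cmath
import math


def slab_transmission(n, d, wl):
    """Power transmission of a single slab of index n in air, normal incidence."""
    r_in = (n - 1) / (n + 1)        # internal reflection amplitude at either face
    t_enter = 2 / (1 + n)           # air -> slab
    t_exit = 2 * n / (n + 1)        # slab -> air
    phi = 2 * math.pi * n * d / wl  # single-pass phase, as in propagation_i
    t = t_enter * t_exit * cmath.exp(1j * phi) / (1 - r_in**2 * cmath.exp(2j * phi))
    return abs(t) ** 2


n, d = 2.0, 0.5  # index and thickness [μm] used above

# k = m*pi/(n*d) is equivalent to wl = 2*n*d/m.
resonant_wls = [2 * n * d / m for m in (3, 4, 5)]
T_at_resonance = [slab_transmission(n, d, wl) for wl in resonant_wls]
```

For m = 3, 4, 5 the resonant wavelengths are 2/3, 1/2 and 2/5 μm, all inside the 380 to 750 nm window, and the transmission there is 1 up to rounding. In particular, the starting thickness puts a transmission maximum exactly at 500 nm.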

Optimization test#

Let’s attempt to minimize transmission at 500 nm by varying thickness.

@jax.jit
def loss(thickness):
    settings = sax.update_settings(sax.get_settings(dielectric_fabry_perot), wl=0.5)
    settings["B"]["di"] = thickness
    settings["air_B"]["nj"] = 2.0
    settings["B"]["ni"] = 2.0
    settings["B_air"]["ni"] = 2.0
    sdict = dielectric_fabry_perot(**settings)
    return jnp.abs(sdict["in", "out"]) ** 2


grad = jax.jit(jax.grad(loss))

initial_thickness = 0.5
optim_init, optim_update, optim_params = opt.adam(step_size=0.01)
optim_state = optim_init(initial_thickness)


def train_step(step, optim_state):
    thickness = optim_params(optim_state)
    lossvalue = loss(thickness)
    gradvalue = grad(thickness)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state


range_ = tqdm.trange(100)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

thickness = optim_params(optim_state)
thickness
Array(0.43726699, dtype=float64)
settings = sax.update_settings(sax.get_settings(dielectric_fabry_perot), wl=wls)
settings["B"]["di"] = thickness
settings["air_B"]["nj"] = 2.0
settings["B"]["ni"] = 2.0
settings["B_air"]["ni"] = 2.0
sdict = dielectric_fabry_perot(**settings)
detected = sdict["in", "out"]

plt.plot(wls * 1e3, jnp.abs(transmitted) ** 2, label="Before (500 nm)")
plt.plot(wls * 1e3, jnp.abs(detected) ** 2, label=f"After ({thickness*1e3:.0f} nm)")
plt.vlines(0.5 * 1e3, 0.6, 1, "k", linestyle="--")
plt.xlabel("λ [nm]")
plt.ylabel("Transmitted intensity")
plt.legend(loc="lower right")
plt.title("Thickness optimization")
plt.show()
[Figure: "Thickness optimization", transmitted intensity versus wavelength (nm) before and after optimization]
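
The optimized thickness can be compared with the closed-form answer. For a lossless slab, transmission is minimal when the round-trip phase is an odd multiple of π, that is d = (m + 1/2) λ / (2n). The sketch below is plain Python, with slab_transmission again a hypothetical helper rather than a SAX function. It lists those thicknesses for λ = 0.5 μm and n = 2, and evaluates the transmission at the thickness reported by the optimizer.

```python
import cmath
import math


def slab_transmission(n, d, wl):
    """Power transmission of a single slab of index n in air, normal incidence."""
    r_in = (n - 1) / (n + 1)
    phi = 2 * math.pi * n * d / wl
    t = (1 - r_in**2) * cmath.exp(1j * phi) / (1 - r_in**2 * cmath.exp(2j * phi))
    return abs(t) ** 2


n, wl = 2.0, 0.5  # index and target wavelength [μm]

# Thicknesses [μm] giving a transmission minimum at wl.
d_min_candidates = [(m + 0.5) * wl / (2 * n) for m in range(2, 5)]

# Lowest transmission a lossless slab of this index can reach.
r_in = (n - 1) / (n + 1)
T_floor = ((1 - r_in**2) / (1 + r_in**2)) ** 2

# Transmission at the thickness returned by the optimizer above.
T_optimized = slab_transmission(n, 0.43726699, wl)
```

The candidates are 0.3125, 0.4375 and 0.5625 μm, so the optimizer landed close to the m = 3 solution, and the transmission there is about 0.64, the floor set by the index contrast. Since the starting thickness of 0.5 μm is a transmission maximum at 500 nm, the analytical gradient there is zero, and the optimizer only moves away from it thanks to rounding errors combined with Adam's gradient normalization.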

General Fabry-Pérot étalon#

We reuse the propagation matrix above and, instead of simple interface matrices, model the Fabry-Pérot mirrors as general lossless reciprocal scatter matrices:

\[\begin{split} \left(\begin{array}{c} E_t \\ E_r \end{array}\right) = E_{out} = SE_{in} = \left(\begin{array}{cc} t & r \\ r & t \end{array}\right) \left(\begin{array}{c} E_0 \\ 0 \end{array}\right) \end{split}\]

For lossless reciprocal systems, we further have the requirements

\[ |t|^2 + |r|^2 = 1 \]

and

\[ \angle t - \angle r = \pm \pi/2 \]

The general Fabry-Pérot cavity is described analytically by the Airy formulas:

def airy_t13(t12, t23, r21, r23, wl, d=1.0, n=1.0):
    """General Fabry-Pérot transmission transfer function (Airy formula)

    Args:
        t12 and r21 : S-parameters of the first mirror
        t23 and r23 : S-parameters of the second mirror
        wl : wavelength
        d : gap between the two mirrors (same length unit as wl)
        n : index of the gap between the two mirrors

    Returns:
        t13 : complex transmission amplitude of the mirror-gap-mirror system

    Note:
        Each mirror is assumed to be lossless and reciprocal : tij = tji, rij = rji
    """
    phi = n * 2 * jnp.pi * d / wl
    return t12 * t23 * jnp.exp(-1j * phi) / (1 - r21 * r23 * jnp.exp(-2j * phi))


def airy_r13(t12, t23, r21, r23, wl, d=1.0, n=1.0):
    """General Fabry-Pérot reflection transfer function (Airy formula)

    Args:
        t12 and r21 : S-parameters of the first mirror
        t23 and r23 : S-parameters of the second mirror
        wl : wavelength
        d : gap between the two mirrors (same length unit as wl)
        n : index of the gap between the two mirrors

    Returns:
        r13 : complex reflection amplitude of the mirror-gap-mirror system

    Note:
        Each mirror is assumed to be lossless and reciprocal : tij = tji, rij = rji
    """
    phi = n * 2 * jnp.pi * d / wl
    return r21 + t12 * t12 * r23 * jnp.exp(-2j * phi) / (
        1 - r21 * r23 * jnp.exp(-2j * phi)
    )

We need to implement the relationship between \(t\) and \(r\) for lossless reciprocal mirrors. The design parameters will be the amplitude and phase of the transmission coefficient. The reflection coefficient is then fully determined:

def t_complex(t_amp, t_ang):
    return t_amp * jnp.exp(-1j * t_ang)


def r_complex(t_amp, t_ang):
    r_amp = jnp.sqrt((1.0 - t_amp**2))
    r_ang = t_ang - jnp.pi / 2
    return r_amp * jnp.exp(-1j * r_ang)
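
To see that these two relations do produce a lossless mirror, the following plain-Python sketch (t_and_r is a hypothetical helper that mirrors t_complex and r_complex without JAX) checks that the matrix S = [[t, r], [r, t]] is unitary for an arbitrary choice of amplitude and phase.

```python
import cmath
import math


def t_and_r(t_amp, t_ang):
    """Same relations as t_complex and r_complex, in plain Python."""
    t = t_amp * cmath.exp(-1j * t_ang)
    r = math.sqrt(1.0 - t_amp**2) * cmath.exp(-1j * (t_ang - math.pi / 2))
    return t, r


t, r = t_and_r(0.6, 0.3)  # arbitrary illustrative values

# Entries of S^dagger S for S = [[t, r], [r, t]].
power_sum = abs(t) ** 2 + abs(r) ** 2                # diagonal, should be 1
cross_term = t.conjugate() * r + r.conjugate() * t   # off-diagonal, should be 0
```

The diagonal entry is the power balance, and the off-diagonal entry vanishes because the π/2 phase offset makes the product of t conjugate and r purely imaginary.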

Let’s see the expected result for half-mirrors:

t_initial = jnp.sqrt(0.5)
d_gap = 2.0
n_gap = 1.0
r_initial = r_complex(t_initial, 0.0)

wls = jnp.linspace(0.38, 0.78, 500)

T_analytical_initial = (
    jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))
    ** 2
)
R_analytical_initial = (
    jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))
    ** 2
)

plt.title(f"t={t_initial:1.3f}, d={d_gap} μm, n={n_gap}")
plt.plot(2 * jnp.pi / wls, T_analytical_initial, label="T")
plt.plot(2 * jnp.pi / wls, R_analytical_initial, label="R")
plt.vlines(
    jnp.arange(6, 11) * jnp.pi / 2.0,
    ymin=0,
    ymax=1,
    color="k",
    linestyle="--",
    label=r"m$\pi$/nd",
)
plt.xlabel(r"k = 2$\pi$/$\lambda$ [1/μm]")  # wls is in μm, so k is in 1/μm
plt.ylabel("Power (units of input)")
plt.legend()
plt.show()

plt.title(f"t={t_initial:1.3f}, d={d_gap} μm, n={n_gap}")
plt.plot(wls * 1e3, T_analytical_initial, label="T")
plt.plot(wls * 1e3, R_analytical_initial, label="R")
plt.xlabel(r"$\lambda$ (nm)")
plt.ylabel("Power (units of input)")
plt.legend()
plt.show()
[Figures: T and R of the analytical Fabry-Pérot model versus k = 2π/λ and versus wavelength (nm)]
# Is power conserved? (to within 0.1%)
assert jnp.isclose(R_analytical_initial + T_analytical_initial, 1, 0.001).all()
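
The extreme values of these curves follow directly from the Airy formulas. The sketch below is a plain-Python transcription of airy_t13 and airy_r13 for two identical mirrors with t_ang = 0 (airy_T_R is a hypothetical helper introduced for this check). It evaluates T and R on a dashed line, where the single-pass phase is a multiple of π, and halfway between two dashed lines.

```python
import cmath
import math


def airy_T_R(t_amp, phi):
    """Power T and R of two identical lossless mirrors separated by a phase phi."""
    t = t_amp                            # t_complex(t_amp, 0)
    r = 1j * math.sqrt(1.0 - t_amp**2)   # r_complex(t_amp, 0), a quarter cycle from t
    e2 = cmath.exp(-2j * phi)
    denom = 1 - r * r * e2
    t13 = t * t * cmath.exp(-1j * phi) / denom
    r13 = r + t * t * r * e2 / denom
    return abs(t13) ** 2, abs(r13) ** 2


t_amp = math.sqrt(0.5)  # half-mirrors, as above

# phi = 2*pi*n*d/wl with n = 1 and d = 2 μm, so phi = 8*pi at wl = 0.5 μm (m = 8).
T_line, R_line = airy_T_R(t_amp, 8.0 * math.pi)
T_mid, R_mid = airy_T_R(t_amp, 8.5 * math.pi)
```

With the phase convention of r_complex, each reflection adds a quarter cycle, so the dashed lines mark the transmission minima (T = 1/9, R = 8/9) and full transmission occurs halfway between them. In both cases R + T equals 1 up to rounding.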

Now let’s do the same with SAX by defining new elements:

def mirror(t_amp=0.5**0.5, t_ang=0.0):
    r_complex_val = r_complex(t_amp, t_ang)
    t_complex_val = t_complex(t_amp, t_ang)
    sdict = {
        ("in", "in"): r_complex_val,
        ("in", "out"): t_complex_val,
        ("out", "in"): t_complex_val,  # (1 - r_complex_val**2)/t_complex_val, # t_ji
        ("out", "out"): r_complex_val,  # -r_complex_val, # r_ji
    }
    return sdict


fabry_perot_tunable, _ = sax.circuit(
    netlist={
        "instances": {
            "mirror1": mirror,
            "gap": propagation_i,
            "mirror2": mirror,
        },
        "connections": {
            "mirror1,out": "gap,in",
            "gap,out": "mirror2,in",
        },
        "ports": {
            "in": "mirror1,in",
            "out": "mirror2,out",
        },
    },
    backend="fg",
)

settings = sax.get_settings(fabry_perot_tunable)
settings
{'mirror1': {'t_amp': 0.7071067811865476, 't_ang': 0.0},
 'gap': {'ni': 1.0, 'di': 0.5, 'wl': 0.532},
 'mirror2': {'t_amp': 0.7071067811865476, 't_ang': 0.0}}
N = 100
wls = jnp.linspace(0.38, 0.78, N)
settings = sax.get_settings(fabry_perot_tunable)
settings = sax.update_settings(settings, wl=wls, t_amp=jnp.sqrt(0.5), t_ang=0.0)
settings["gap"]["ni"] = 1.0
settings["gap"]["di"] = 2.0
sdict = fabry_perot_tunable(**settings)  # evaluate the circuit once
transmitted_initial = sdict["in", "out"]
reflected_initial = sdict["in", "in"]
T_analytical_initial = (
    jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))
    ** 2
)
R_analytical_initial = (
    jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))
    ** 2
)
plt.title(f"t={t_initial:1.3f}, d={d_gap} μm, n={n_gap}")
plt.plot(wls, T_analytical_initial, label="T theory")
plt.scatter(wls, jnp.abs(transmitted_initial) ** 2, label="T SAX")
plt.plot(wls, R_analytical_initial, label="R theory")
plt.scatter(wls, jnp.abs(reflected_initial) ** 2, label="R SAX")
plt.xlabel("λ [μm]")  # the x-axis is the wavelength, not the wavevector
plt.ylabel("Power (units of input)")
plt.figlegend(framealpha=1.0)
plt.show()
[Figure: T and R versus wavelength, theory compared with SAX]

Wavelength-dependent Fabry-Pérot étalon#

Let’s repeat with a model where parameters can be wavelength-dependent. To comply with the optimizer object, we will stack all design parameters in a single array (the first N entries hold the amplitudes, the last N the phases):

ts_initial = jnp.zeros(2 * N)
ts_initial = ts_initial.at[0:N].set(jnp.sqrt(0.5))

Each wavelength gets its own \(t\) parameters: we pass arrays of length N for the amplitude and the phase, so that all wavelengths are evaluated in a single call.

wls = jnp.linspace(0.38, 0.78, N)
transmitted = jnp.zeros_like(wls)
reflected = jnp.zeros_like(wls)
settings = sax.get_settings(fabry_perot_tunable)
settings = sax.update_settings(
    settings, wl=wls, t_amp=ts_initial[:N], t_ang=ts_initial[N:]
)
settings["gap"]["ni"] = 1.0
settings["gap"]["di"] = 2.0
# Perform computation
sdict = fabry_perot_tunable(**settings)
transmitted = jnp.abs(sdict["in", "out"]) ** 2
reflected = jnp.abs(sdict["in", "in"]) ** 2
plt.plot(wls * 1e3, T_analytical_initial, label="T theory")
plt.scatter(wls * 1e3, transmitted, label="T SAX")
plt.plot(wls * 1e3, R_analytical_initial, label="R theory")
plt.scatter(wls * 1e3, reflected, label="R SAX")
plt.xlabel("λ [nm]")
plt.ylabel("Transmitted and reflected intensities")
plt.legend(loc="upper right")
plt.title(f"t={t_initial:1.3f}, d={d_gap} μm, n={n_gap}")
plt.show()
[Figure: transmitted and reflected intensities versus wavelength (nm), theory compared with SAX]

Since it seems to work, let’s add a target and optimize some harmonics away:

def lorentzian(l0, dl, wl, A):
    return A / ((wl - l0) ** 2 + (0.5 * dl) ** 2)


target = lorentzian(533.0, 20.0, wls * 1e3, 100.0)

plt.scatter(wls * 1e3, transmitted, label="T SAX")
plt.scatter(wls * 1e3, reflected, label="R SAX")
plt.plot(wls * 1e3, target, "r", linewidth=2, label="target")
plt.xlabel("λ [nm]")
plt.ylabel("Transmitted and reflected intensities")
plt.legend(loc="upper right")
plt.title(f"t={t_initial:1.3f}, d={d_gap} μm, n={n_gap}")
plt.show()
[Figure: T and R from SAX versus wavelength (nm), with the Lorentzian target]
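
The amplitude A = 100 is not arbitrary: with this parametrization the peak value is A / (dl/2)², so a 20 nm linewidth requires A = 100 for the target to peak at unit transmission. A minimal plain-Python check, reusing the same definition of lorentzian with scalar inputs:

```python
def lorentzian(l0, dl, wl, A):
    return A / ((wl - l0) ** 2 + (0.5 * dl) ** 2)


# Peak value at the center wavelength: A / (dl / 2)**2.
peak = lorentzian(533.0, 20.0, 533.0, 100.0)

# Value half a linewidth away from the center.
half = lorentzian(533.0, 20.0, 543.0, 100.0)
```

The peak evaluates to 1 and the value at 543 nm to 0.5, so dl is the full width at half maximum and the target stays within the physically reachable range of 0 to 1.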
@jax.jit
def loss(ts):
    N = ts.shape[0] // 2  # first half: amplitudes, second half: phases
    wls = jnp.linspace(0.38, 0.78, N)
    target = lorentzian(533.0, 20.0, wls * 1e3, 100.0)
    settings = sax.get_settings(fabry_perot_tunable)
    settings = sax.update_settings(settings, wl=wls, t_amp=ts[:N], t_ang=ts[N:])
    settings["gap"]["ni"] = 1.0
    settings["gap"]["di"] = 2.0
    sdict = fabry_perot_tunable(**settings)
    transmitted = jnp.abs(sdict["in", "out"]) ** 2
    return (jnp.abs(transmitted - target) ** 2).mean()


grad = jax.jit(jax.grad(loss))
optim_init, optim_update, optim_params = opt.adam(step_size=0.001)


def train_step(step, optim_state):
    ts = optim_params(optim_state)
    lossvalue = loss(ts)
    gradvalue = grad(ts)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, gradvalue, optim_state


range_ = tqdm.trange(2000)

optim_state = optim_init(ts_initial)
for step in range_:
    lossvalue, gradvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

The optimized parameters are now wavelength-dependent:

ts_optimal = optim_params(optim_state)

plt.scatter(wls * 1e3, ts_initial[:N], label="t initial")
plt.scatter(wls * 1e3, ts_optimal[:N], label="t optimal")
plt.xlabel("λ [nm]")
plt.ylabel(r"|t| $(\lambda)$")
plt.legend(loc="best")
plt.title(f"d={d_gap} μm, n={n_gap}")
plt.show()

plt.scatter(wls * 1e3, ts_initial[N:], label="t initial")
plt.scatter(wls * 1e3, ts_optimal[N:], label="t optimal")
plt.xlabel("λ [nm]")
plt.ylabel(r"angle $t (\lambda)$ (rad)")
plt.legend(loc="best")
plt.title(f"d={d_gap} μm, n={n_gap}")
plt.show()
[Figures: initial and optimized |t| and angle of t versus wavelength (nm)]

Visualizing the result:

wls = jnp.linspace(0.38, 0.78, N)
transmitted_optimal = jnp.zeros_like(wls)
reflected_optimal = jnp.zeros_like(wls)

settings = sax.get_settings(fabry_perot_tunable)
settings = sax.update_settings(
    settings, wl=wls, t_amp=ts_optimal[:N], t_ang=ts_optimal[N:]
)
settings["gap"]["ni"] = 1.0
settings["gap"]["di"] = 2.0
sdict = fabry_perot_tunable(**settings)  # evaluate the circuit once
transmitted_optimal = jnp.abs(sdict["in", "out"]) ** 2
reflected_optimal = jnp.abs(sdict["in", "in"]) ** 2
plt.scatter(wls * 1e3, transmitted_optimal, label="T")
plt.scatter(wls * 1e3, reflected_optimal, label="R")
plt.plot(wls * 1e3, lorentzian(533, 20, wls * 1e3, 100), "r", label="target")
plt.xlabel("λ [nm]")
plt.ylabel("Transmitted and reflected intensities")
plt.legend(loc="upper right")
plt.title(rf"Optimized t($\lambda$), d={d_gap} μm, n={n_gap}")
plt.show()
[Figure: optimized T and R versus wavelength (nm), with the Lorentzian target]

The hard part is now to find physical stacks that implement \(t(\lambda)\). However, the ease with which the loss function can be modified and extended opens opportunities for regularization and more elaborate objective functions.

The models above are available in sax.models.thinfilm, and can straightforwardly be extended to propagation at an angle, s and p polarizations, nonreciprocal systems, and systems with losses.
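
As an illustration of the first two extensions, here is a minimal plain-Python sketch of the Fresnel reflection amplitudes at oblique incidence for s and p polarization. It is not the sax.models.thinfilm implementation: fresnel_oblique is a hypothetical helper, the indices are assumed real, and the sign convention chosen for r_p (opposite to r_s at normal incidence) is one of several in common use.

```python
import cmath
import math


def fresnel_oblique(n1, n2, theta1):
    """Amplitude reflection coefficients (r_s, r_p) at incidence angle theta1 [rad]."""
    cos1 = math.cos(theta1)
    # Snell's law; cmath.sqrt returns a complex value past the critical angle.
    sin2 = n1 * math.sin(theta1) / n2
    cos2 = cmath.sqrt(1 - sin2**2)
    r_s = (n1 * cos1 - n2 * cos2) / (n1 * cos1 + n2 * cos2)
    r_p = (n2 * cos1 - n1 * cos2) / (n2 * cos1 + n1 * cos2)
    return r_s, r_p


n1, n2 = 1.0, 2.0  # illustrative values: air to a medium of index 2

# Normal incidence: r_s matches the reflection used in fresnel_mirror_ij.
r_s_normal, r_p_normal = fresnel_oblique(n1, n2, 0.0)

# Brewster angle: the p-polarized reflection vanishes.
theta_brewster = math.atan(n2 / n1)
r_s_brewster, r_p_brewster = fresnel_oblique(n1, n2, theta_brewster)
```

At normal incidence r_s reduces to (n1 - n2) / (n1 + n2), the value used in the interface model above, while at the Brewster angle r_p is zero up to rounding. A full extension would also replace the propagation phase by its projection along the layer normal.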