Skip to content

Data Parsers

Let's parse some data.

Imports

from functools import cache

import altair as alt
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
import sax

Lumerical Parser

The SiEPIC EBeam PDK contains many data files in Lumerical format. Let's download one of them:

url = "https://raw.githubusercontent.com/SiEPIC/SiEPIC_EBeam_PDK/refs/heads/master/Lumerical_EBeam_CML/EBeam/source_data/ebeam_dc_te1550/dc_gap%3D200nm_Lc%3D0um.sparam"
# Use a timeout so a stalled connection cannot hang the notebook forever.
response = requests.get(url, timeout=30)
# Fail loudly instead of printing/parsing an HTTP error page as S-parameter data.
response.raise_for_status()
content = response.text
print(content[:1000])
('port 1','TE',1,'port 1',1,'transmission')
(101,3)
1.8737e+014 0.0165931   -3.10307
1.87495e+014    0.0150283   -2.93033
1.8762e+014 0.0133308   -2.74368
1.87745e+014    0.0115503   -2.5359
1.8787e+014 0.00975763  -2.29485
1.87995e+014    0.00806292  -2.0002
1.8812e+014 0.00665148  -1.6222
1.88245e+014    0.00581419  -1.14149
1.8837e+014 0.0058327   -0.609826
1.88495e+014    0.00667996  -0.137315
1.88619e+014    0.0080483   0.232713
1.88744e+014    0.00965048  0.522573
1.88869e+014    0.0113097   0.761433
1.88994e+014    0.0129225   0.968779
1.89119e+014    0.014424    1.15621
1.89244e+014    0.01577 1.33067
1.89369e+014    0.016928    1.49646
1.89494e+014    0.0178741   1.65641
1.89619e+014    0.0185905   1.81246
1.89744e+014    0.019065    1.96603
1.89869e+014    0.01929 2.11824
1.89993e+014    0.0192627   2.27007
1.90118e+014    0.0189848   2.42242
1.90243e+014    0.0184622   2.57626
1.90368e+014    0.0177056   2.73266
1.90493e+014    0.0167299   2.89302
1.90618e+014    0.0155552   3.05919
1.90743e+014    0.0142071   3.23383
1.90868e+014    0.0127182   3.42095
1.90993e+014    0.0111316   3.62693
1.91118e+01
df = sax.parse_lumerical_dat(content, convert_f_to_wl=True)
df
/tmp/ipykernel_2344/1628850161.py:1: ExperimentalWarning: The [`parse_lumerical_dat`][sax.parse_lumerical_dat] function is experimental. If you encounter any issues, Please file a bug report here: https://github.com/flaport/sax/issues .
  df = sax.parse_lumerical_dat(content, convert_f_to_wl=True)
wl port_in port_out mode_in mode_out amp phi
0 1.600002 port_1 port_1 1 1 0.016593 -3.10307
1 1.598936 port_1 port_1 1 1 0.015028 -2.93033
2 1.597870 port_1 port_1 1 1 0.013331 -2.74368
3 1.596807 port_1 port_1 1 1 0.011550 -2.53590
4 1.595744 port_1 port_1 1 1 0.009758 -2.29485
... ... ... ... ... ... ... ...
1611 1.503759 port_4 port_4 1 1 0.015607 20.33110
1612 1.502817 port_4 port_4 1 1 0.016338 20.50640
1613 1.501876 port_4 port_4 1 1 0.016867 20.67810
1614 1.500936 port_4 port_4 1 1 0.017183 20.84810
1615 1.499997 port_4 port_4 1 1 0.017282 21.01760

1616 rows × 7 columns

The parsed result is a dataframe in tidy format with the following columns:

'wl', 'port_in', 'port_out', 'mode_in', 'mode_out', 'amp', 'phi'

In this case, it's a single-mode dataframe:

print(f"{df.mode_in.unique()=}")
print(f"{df.mode_out.unique()=}")
df.mode_in.unique()=array([1])
df.mode_out.unique()=array([1])

So, if we want, we can drop those columns:

df = sax.parse_lumerical_dat(content, convert_f_to_wl=True)
df = df.drop(columns=["mode_in", "mode_out"])
df
/tmp/ipykernel_2344/2883338943.py:1: ExperimentalWarning: The [`parse_lumerical_dat`][sax.parse_lumerical_dat] function is experimental. If you encounter any issues, Please file a bug report here: https://github.com/flaport/sax/issues .
  df = sax.parse_lumerical_dat(content, convert_f_to_wl=True)
wl port_in port_out amp phi
0 1.600002 port_1 port_1 0.016593 -3.10307
1 1.598936 port_1 port_1 0.015028 -2.93033
2 1.597870 port_1 port_1 0.013331 -2.74368
3 1.596807 port_1 port_1 0.011550 -2.53590
4 1.595744 port_1 port_1 0.009758 -2.29485
... ... ... ... ... ...
1611 1.503759 port_4 port_4 0.015607 20.33110
1612 1.502817 port_4 port_4 0.016338 20.50640
1613 1.501876 port_4 port_4 0.016867 20.67810
1614 1.500936 port_4 port_4 0.017183 20.84810
1615 1.499997 port_4 port_4 0.017282 21.01760

1616 rows × 5 columns

The plotting library altair is a perfect fit for visualizing dataframes in tidy format:

chart = (
    alt.Chart(df.query("port_in=='port_1'"))
    .mark_line()
    .encode(
        x=alt.X("wl", scale=alt.Scale(domain=(df["wl"].min(), df["wl"].max()))),
        y=alt.Y("amp", scale=alt.Scale(domain=(-0.05, 1.05))),
        color="port_out",
    )
    .properties(width="container")
).interactive()
chart

Transforming into xarray

Very often we would like to represent this as an xarray (think of it as a multi-dimensional dataframe):

xarr = sax.to_xarray(df, target_names=["amp", "phi"])
print(xarr)
<xarray.DataArray (wl: 101, port_in: 4, port_out: 4, targets: 2)> Size: 26kB
Array([[[[ 1.72823e-02,  2.10176e+01],
         [ 1.83659e-03,  2.22772e+01],
         [ 9.88816e-01,  9.76622e+00],
         [ 6.10558e-02,  1.13450e+01]],

        [[ 1.83659e-03,  2.22772e+01],
         [ 1.72823e-02,  2.10176e+01],
         [ 6.10558e-02,  1.13450e+01],
         [ 9.88816e-01,  9.76622e+00]],

        [[ 9.88816e-01,  9.76622e+00],
         [ 6.10558e-02,  1.13450e+01],
         [ 1.72823e-02,  2.10176e+01],
         [ 1.83659e-03,  2.22772e+01]],

        [[ 6.10558e-02,  1.13450e+01],
         [ 9.88816e-01,  9.76622e+00],
         [ 1.83659e-03,  2.22772e+01],
         [ 1.72823e-02,  2.10176e+01]]],

...

       [[[ 1.65931e-02, -3.10307e+00],
         [ 2.30076e-03, -1.03462e+00],
         [ 9.85663e-01, -1.95284e+00],
         [ 1.06185e-01, -4.04115e-01]],

        [[ 2.30076e-03, -1.03462e+00],
         [ 1.65931e-02, -3.10307e+00],
         [ 1.06185e-01, -4.04115e-01],
         [ 9.85663e-01, -1.95284e+00]],

        [[ 9.85663e-01, -1.95284e+00],
         [ 1.06185e-01, -4.04115e-01],
         [ 1.65931e-02, -3.10307e+00],
         [ 2.30076e-03, -1.03462e+00]],

        [[ 1.06185e-01, -4.04115e-01],
         [ 9.85663e-01, -1.95284e+00],
         [ 2.30076e-03, -1.03462e+00],
         [ 1.65931e-02, -3.10307e+00]]]], dtype=float64)
Coordinates:
  * wl        (wl) float64 808B 1.5 1.501 1.502 1.503 ... 1.597 1.598 1.599 1.6
  * port_in   (port_in) object 32B 'port_1' 'port_2' 'port_3' 'port_4'
  * port_out  (port_out) object 32B 'port_1' 'port_2' 'port_3' 'port_4'
  * targets   (targets) object 16B 'amp' 'phi'

Interpolating an xarray:

To interpolate over the float coordinates of the xarray:

sax.interpolate_xarray(xarr, wl=[1.5, 1.6])
{'amp': Array([[[0.01728201, 0.01658951],
         [0.00183677, 0.00230078],
         [0.98881583, 0.98566318],
         [0.06105652, 0.10618373]],

        [[0.00183677, 0.00230078],
         [0.01728201, 0.01658951],
         [0.06105652, 0.10618373],
         [0.98881583, 0.98566318]],

        [[0.98881583, 0.98566318],
         [0.06105652, 0.10618373],
         [0.01728201, 0.01658951],
         [0.00183677, 0.00230078]],

        [[0.06105652, 0.10618373],
         [0.98881583, 0.98566318],
         [0.00183677, 0.00230078],
         [0.01728201, 0.01658951]]], dtype=float64),
 'phi': Array([[[21.01711034, -3.10267416],
         [22.27653325, -1.03404903],
         [ 9.76588091, -1.95257198],
         [11.34465825, -0.40384983]],

        [[22.27653325, -1.03404903],
         [21.01711034, -3.10267416],
         [11.34465825, -0.40384983],
         [ 9.76588091, -1.95257198]],

        [[ 9.76588091, -1.95257198],
         [11.34465825, -0.40384983],
         [21.01711034, -3.10267416],
         [22.27653325, -1.03404903]],

        [[11.34465825, -0.40384983],
         [ 9.76588091, -1.95257198],
         [22.27653325, -1.03404903],
         [21.01711034, -3.10267416]]], dtype=float64)}

String coordinates cannot be interpolated over, but they can be selected:

sax.interpolate_xarray(xarr, wl=[1.5, 1.6], port_in="port_1", port_out="port_1")
{'amp': Array([0.01728201, 0.01658951], dtype=float64),
 'phi': Array([21.01711034, -3.10267416], dtype=float64)}

or, to get all outputs for a given input:

sax.interpolate_xarray(xarr, wl=[1.5, 1.6], port_in="port_1")
{'amp': Array([[0.01728201, 0.01658951],
        [0.00183677, 0.00230078],
        [0.98881583, 0.98566318],
        [0.06105652, 0.10618373]], dtype=float64),
 'phi': Array([[21.01711034, -3.10267416],
        [22.27653325, -1.03404903],
        [ 9.76588091, -1.95257198],
        [11.34465825, -0.40384983]], dtype=float64)}

Creating a model

Using all of the above, we can create a model. The common boilerplate can be divided into two steps:

Important

This is the recommended way to create a sax model from a data file!

# 1. The cached data loader:


@cache
def load_dc_xarray():
    """Download, parse and cache the SiEPIC EBeam directional-coupler data.

    The data is the Lumerical ``.sparam`` file for the ``ebeam_dc_te1550``
    coupler with gap=200nm and Lc=0um. Thanks to ``@cache`` the download and
    parsing run only once per session; later calls return the same object.

    Returns:
        The xarray produced by ``sax.to_xarray`` with the ``amp`` and ``phi``
        targets (non-target dims: ``wl``, ``port_in``, ``port_out``).

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    url = (
        "https://raw.githubusercontent.com/SiEPIC/SiEPIC_EBeam_PDK/refs/heads/master/Lumerical_EBeam_CML/EBeam/source_data/ebeam_dc_te1550/dc_gap%3D200nm_Lc%3D0um.sparam"
    )
    # A timeout prevents an unresponsive server from hanging the caller forever.
    response = requests.get(url, timeout=30)
    # Don't try to parse an HTTP error page as S-parameter data. Since an
    # exception propagates, @cache does not memoize the failure.
    response.raise_for_status()
    content = response.text
    # or for local data probably more something like this:
    # path = Path(__file__).parent / "relative" / "path" / "to" / "data.dat"
    # content = Path(path).read_text()
    df = sax.parse_lumerical_dat(content, convert_f_to_wl=True)

    # only keep columns that should be used
    # (i.e. columns that uniquely predict the target, without duplication, e.g. no 'f' and 'wl' together)
    df = df.drop(columns=["mode_in", "mode_out"])

    # now we can transform to xarray
    xarr = sax.to_xarray(df, target_names=["amp", "phi"])

    return xarr


# 2. The model function
@jax.jit  # if you can, try to jit it
def dc_model(
    wl=1.5,  # all non-port, non-target columns should be exposed as keyword arguments
) -> sax.SDict:
    """Directional-coupler model built from tabulated Lumerical data.

    Each (input, output) port pair is interpolated separately at ``wl`` and
    converted from amplitude/phase to a complex transmission coefficient.

    Args:
        wl: wavelength(s) at which to interpolate the data.

    Returns:
        An SDict mapping ``(port_in, port_out)`` model-port pairs to
        ``amp * exp(1j * phi)``.
    """
    # Evaluate the data loading eagerly instead of tracing it under jit.
    with jax.ensure_compile_time_eval():
        table = load_dc_xarray()

    # model port name -> port name used in the data file
    model_to_data = {
        "in0": "port_1",
        "in1": "port_2",
        "out0": "port_4",
        "out1": "port_3",
    }

    def transmission(src, dst):
        # don't forget to add more keyword arguments here if your data supports it!
        sample = sax.interpolate_xarray(
            table, wl=wl, port_in=str(src), port_out=str(dst)
        )
        return sample["amp"] * jnp.exp(1j * sample["phi"])

    # Outer loop over inputs, inner loop over outputs (same key order as a nested for).
    return {
        (a, b): transmission(src, dst)
        for a, src in model_to_data.items()
        for b, dst in model_to_data.items()
    }

Tip

If linearly interpolating the data, as in the example above, is not accurate enough for you, have a look at the 'surface models' example. There we go into more depth on how to fit an accurate analytical model to the data.

result = dc_model()
result
/tmp/ipykernel_2344/1990555231.py:14: ExperimentalWarning: The [`parse_lumerical_dat`][sax.parse_lumerical_dat] function is experimental. If you encounter any issues, Please file a bug report here: https://github.com/flaport/sax/issues .
  df = sax.parse_lumerical_dat(content, convert_f_to_wl=True)





{('in0', 'in0'): Array(-0.00971187+0.01429502j, dtype=complex128),
 ('in0', 'in1'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('in0', 'out0'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('in0', 'out1'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('in1', 'in0'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('in1', 'in1'): Array(-0.00971187+0.01429502j, dtype=complex128),
 ('in1', 'out0'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('in1', 'out1'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('out0', 'in0'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('out0', 'in1'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('out0', 'out0'): Array(-0.00971187+0.01429502j, dtype=complex128),
 ('out0', 'out1'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('out1', 'in0'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('out1', 'in1'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('out1', 'out0'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('out1', 'out1'): Array(-0.00971187+0.01429502j, dtype=complex128)}

SDense for performance

A model returning an SDict is usually the easiest to work with; however, we can also return an SDense, which in this case should be more performant, as only one xarray interpolation is necessary:

@jax.jit
def dc_model2(
    wl=1.5,  # all non-port, non-target columns should be exposed as keyword arguments
) -> sax.SDense:
    """Directional-coupler model returning a dense S-matrix (SDense).

    The whole ``port_in x port_out`` table is interpolated in a single
    ``sax.interpolate_xarray`` call instead of one call per port pair.

    Args:
        wl: wavelength(s) at which to interpolate the data.

    Returns:
        ``(S, port_map)``, where ``S[..., i, j]`` is the transmission from the
        port with index ``i`` to the port with index ``j`` (port axes last),
        and ``port_map`` maps model port names to those indices.

    Raises:
        ValueError: if the xarray's trailing dims are not
            ``(port_in, port_out, targets)`` or if ``port_in`` and
            ``port_out`` list the ports in a different order.
    """
    with jax.ensure_compile_time_eval():
        xarr = load_dc_xarray()

    # data-file port name -> model port name
    ports = {
        "port_1": "in0",
        "port_2": "in1",
        "port_4": "out0",
        "port_3": "out1",
    }

    # By not specifying ports, the whole table is interpolated at once. For
    # this to work, the last three dims of the xarray must be
    # (port_in, port_out, targets). Check it instead of just hoping.
    if tuple(xarr.dims[-3:]) != ("port_in", "port_out", "targets"):
        raise ValueError(
            "expected the last three xarray dims to be "
            f"('port_in', 'port_out', 'targets'), got {tuple(xarr.dims)}"
        )
    # One port_map is used for both axes, so both must list ports in the same order.
    if list(xarr.coords["port_in"].values) != list(xarr.coords["port_out"].values):
        raise ValueError("port_in and port_out coordinates must list the same ports in the same order")

    interpolated = sax.interpolate_xarray(xarr, wl=wl)  # you can add more keyword args here
    S = interpolated["amp"] * jnp.exp(1j * interpolated["phi"])

    # The interpolated arrays appear to come back as (port_in, port_out, *wl_shape)
    # (verify for your own data). SDense wants the port axes last, in
    # (in, out) order. `S.T` would move them last but ALSO swap in and out,
    # which is only harmless for reciprocal data such as this coupler.
    S = jnp.moveaxis(S, (0, 1), (-2, -1))

    port_map = {ports[k]: i for i, k in enumerate(xarr.coords["port_in"].values)}
    return S, port_map  # this is an SDense!
result2 = sax.sdict(dc_model2())
result2
{('in0', 'in0'): Array(-0.00971187+0.01429502j, dtype=complex128),
 ('in0', 'in1'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('in0', 'out0'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('in0', 'out1'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('in1', 'in0'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('in1', 'in1'): Array(-0.00971187+0.01429502j, dtype=complex128),
 ('in1', 'out0'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('in1', 'out1'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('out0', 'in0'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('out0', 'in1'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('out0', 'out0'): Array(-0.00971187+0.01429502j, dtype=complex128),
 ('out0', 'out1'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('out1', 'in0'): Array(-0.93184646-0.33078529j, dtype=complex128),
 ('out1', 'in1'): Array(0.0208836-0.05737398j, dtype=complex128),
 ('out1', 'out0'): Array(-0.00176248-0.0005171j, dtype=complex128),
 ('out1', 'out1'): Array(-0.00971187+0.01429502j, dtype=complex128)}
# Compare the complex entries, not only their magnitudes, so that a phase
# (or in/out) mismatch between the two models would also show up.
for k in result:
    print(k, abs(result[k]), abs(result2[k]), abs(result[k] - result2[k]))
('in0', 'in0') 0.017282013135807243 0.017282013135807243 0.0
('in0', 'in1') 0.0018367690229005554 0.0018367690229005554 0.0
('in0', 'out0') 0.06105651990490266 0.06105651990490266 0.0
('in0', 'out1') 0.9888158266681615 0.9888158266681615 0.0
('in1', 'in0') 0.0018367690229005554 0.0018367690229005554 0.0
('in1', 'in1') 0.017282013135807243 0.017282013135807243 0.0
('in1', 'out0') 0.9888158266681615 0.9888158266681615 0.0
('in1', 'out1') 0.06105651990490266 0.06105651990490266 0.0
('out0', 'in0') 0.06105651990490266 0.06105651990490266 0.0
('out0', 'in1') 0.9888158266681615 0.9888158266681615 0.0
('out0', 'out0') 0.017282013135807243 0.017282013135807243 0.0
('out0', 'out1') 0.0018367690229005554 0.0018367690229005554 0.0
('out1', 'in0') 0.9888158266681615 0.9888158266681615 0.0
('out1', 'in1') 0.06105651990490266 0.06105651990490266 0.0
('out1', 'out0') 0.0018367690229005554 0.0018367690229005554 0.0
('out1', 'out1') 0.017282013135807243 0.017282013135807243 0.0

We can plot this result (note: we're plotting amplitudes now):

wl = sax.wl_c()
result = dc_model(wl=wl)
df = (
    sax.to_df(result, wl=wl)
    .query('port_in=="in0"')
    .drop(columns=["port_in", "mode_in", "mode_out"])
    .reset_index(drop=True)
)
chart = (
    alt.Chart(df)
    .mark_line()
    .encode(
        x=alt.X("wl", scale=alt.Scale(domain=(df["wl"].min(), df["wl"].max()))),
        y=alt.Y("amp", scale=alt.Scale(domain=(-0.05, 1.05))),
        color="port_out",
    )
    .properties(width="container")
).interactive()
chart

Touchstone parser

We also have a Touchstone parser, which returns a dataframe. For example, to open a 6×6 Touchstone S-matrix, you could do something like this:

sax.parse_touchstone(
    "./<filename>.s6p",
    ports=["in0", "in1", "out0", "out1", "out2", "out3"], # you need to supply port labels for this format
)

You can use an approach similar to the one above to create your models from Touchstone files.