Skip to content

Sparse COO

SAX supports three S-parameter representations: SDict (dictionary), SDense (dense array), and SCoo (sparse COO format). The COO format is the most memory-efficient for large, sparse circuits but requires more careful handling. This notebook explains when and how to use the SCoo format.

import time

import jax
import jax.numpy as jnp

import sax

S-Parameter Representations

SAX provides three ways to represent S-parameters:

| Format | Type | Best For | Memory | Access |
|--------|------|----------|--------|--------|
| SDict | `Dict[Tuple[str, str], Array]` | Interactive use, debugging | Medium | O(1) by port name |
| SDense | `Tuple[Array, Dict]` | Dense matrices, small circuits | High | O(1) by index |
| SCoo | `Tuple[Array, Array, Array, Dict]` | Large sparse circuits | Low | Sparse iteration |

The COO (Coordinate) format stores only the non-zero elements, in three arrays (entries that are not listed are implicitly zero):

- `i`: row indices
- `j`: column indices
- `data`: S-parameter values

A fourth element, the port map, is a dictionary that maps port names to matrix indices.

Creating SCoo Models

Here's how to create a model that returns SCoo format directly. This is useful when you need maximum performance for large circuits:

def my_coo():
    """A 5-port splitter in SCoo format (4 inputs, 1 output).

    Every input port ``in0``..``in3`` couples to the single output port
    ``out0`` with amplitude 0.5, and the same entries are mirrored so the
    matrix is symmetric. Entries that are not listed are implicitly zero.

    Returns:
        A tuple ``(i, j, data, port_map)``: row indices, column indices,
        S-parameter values, and a dict mapping port names to matrix indices.
    """
    n_in, n_out = 4, 1

    # Inputs occupy indices 0..n_in-1; the outputs follow right after them.
    port_map = {f"in{k}": k for k in range(n_in)}
    port_map.update({f"out{k}": n_in + k for k in range(n_out)})

    # One forward entry per input (in_k -> out0), all with amplitude 0.5.
    src = jnp.arange(n_in)
    dst = jnp.full_like(src, n_in)
    amplitude = 0.5 * jnp.ones(n_in)

    # Reciprocity: append the transposed entries (out0 -> in_k).
    rows = jnp.concatenate([src, dst])
    cols = jnp.concatenate([dst, src])
    values = jnp.concatenate([amplitude, amplitude], 0)

    return (rows, cols, values, port_map)


# Let's see what this returns: (row indices, column indices, values, port map)
result = my_coo()
labels = ("Row indices (i):", "Col indices (j):", "Values:", "Port map:")
for label, part in zip(labels, result):
    print(label, part)
Row indices (i): [0 1 2 3 4 4 4 4]
Col indices (j): [4 4 4 4 0 1 2 3]
Values: [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
Port map: {'in0': 0, 'in1': 1, 'in2': 2, 'in3': 3, 'out0': 4}

Using SCoo in Circuits

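You can use SCoo models in circuits just like any other model. Pass `return_type="scoo"` to get the circuit result in SCoo format. Note that the result is not guaranteed to be sparse: as the output below shows, it can still contain explicit zero entries.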

# Wrap the SCoo model in a single-instance circuit that exposes all of its
# ports unchanged; return_type="scoo" makes the circuit return SCoo as well.
coo_ports = ("in0", "in1", "in2", "in3", "out0")
circuit, _ = sax.circuit(
    netlist={
        "instances": {"coo": "coo"},
        "connections": {},
        "ports": {port: f"coo,{port}" for port in coo_ports},
    },
    models={"coo": my_coo},
    backend="klu",
    return_type="scoo",
)

scoo_result = circuit()
print("SCoo result:", scoo_result)
SCoo result: (Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4], dtype=int64), Array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1,
       2, 3, 4], dtype=int64), Array([0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0.5+0.j, 0. +0.j, 0. +0.j,
       0. +0.j, 0. +0.j, 0.5+0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
       0.5+0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0.5+0.j, 0.5+0.j,
       0.5+0.j, 0.5+0.j, 0.5+0.j, 0. +0.j], dtype=complex128), {'in0': 0, 'in1': 1, 'in2': 2, 'in3': 3, 'out0': 4})

Converting Between Formats

Use `sax.sdict` and `sax.sdense` to convert between formats:

# Convert SCoo to SDict for easier inspection: one entry per (port, port) pair.
sdict_result = sax.sdict(scoo_result)
print("As SDict:")
for port_pair, val in sdict_result.items():
    src_port, dst_port = port_pair
    print(f"  {src_port} -> {dst_port}: {val}")

# Convert to SDense: a (matrix, port_map) pair.
sdense_result = sax.sdense(scoo_result)
dense_matrix = sdense_result[0]
dense_port_map = sdense_result[1]
print("\nAs SDense (matrix, port_map):")
print("Matrix shape:", dense_matrix.shape)
print("Port map:", dense_port_map)
As SDict:
  in0 -> in0: 0j
  in0 -> in1: 0j
  in0 -> in2: 0j
  in0 -> in3: 0j
  in0 -> out0: (0.5+0j)
  in1 -> in0: 0j
  in1 -> in1: 0j
  in1 -> in2: 0j
  in1 -> in3: 0j
  in1 -> out0: (0.5+0j)
  in2 -> in0: 0j
  in2 -> in1: 0j
  in2 -> in2: 0j
  in2 -> in3: 0j
  in2 -> out0: (0.5+0j)
  in3 -> in0: 0j
  in3 -> in1: 0j
  in3 -> in2: 0j
  in3 -> in3: 0j
  in3 -> out0: (0.5+0j)
  out0 -> in0: (0.5+0j)
  out0 -> in1: (0.5+0j)
  out0 -> in2: (0.5+0j)
  out0 -> in3: (0.5+0j)
  out0 -> out0: 0j



As SDense (matrix, port_map):
Matrix shape: (5, 5)
Port map: {'in0': 0, 'in1': 1, 'in2': 2, 'in3': 3, 'out0': 4}

Memory Efficiency Comparison

For large sparse circuits, SCoo can use significantly less memory. Let's compare the formats with a rough, back-of-the-envelope estimate:

def estimate_memory(
    n_ports, n_nonzero, dtype_size=8, index_size=None, name_overhead=50
):
    """Estimate memory usage for different formats (in bytes).

    These are rough, back-of-the-envelope numbers meant only for comparing
    the formats against each other. They ignore container and allocator
    overheads that depend on the Python/JAX runtime.

    Args:
        n_ports: Number of ports; the S-matrix is ``n_ports x n_ports``.
        n_nonzero: Number of stored (non-zero) S-matrix entries.
        dtype_size: Bytes per S-parameter value (e.g. 8 for complex64,
            16 for complex128).
        index_size: Bytes per row/column index in the SCoo format (e.g. 4
            for int32, 8 for int64). Defaults to ``dtype_size``, which keeps
            the original single-element-size estimate.
        name_overhead: Rough bytes charged per port-name string. SDict keys
            hold two names per entry; the SCoo port map holds one per port.

    Returns:
        A dict mapping ``"SDict"``, ``"SDense"`` and ``"SCoo"`` to their
        estimated sizes in bytes.
    """
    if index_size is None:
        index_size = dtype_size

    # SDict: one value per non-zero element, keyed by a (port, port) tuple.
    sdict_size = n_nonzero * (2 * name_overhead + dtype_size)

    # SDense: the full n x n matrix, zeros included.
    sdense_size = n_ports * n_ports * dtype_size

    # SCoo: two index arrays and one value array of length n_nonzero, plus
    # the port map (one name per port). Indices and values are sized
    # separately because they usually have different dtypes.
    scoo_size = n_nonzero * (2 * index_size + dtype_size) + n_ports * name_overhead

    return {
        "SDict": sdict_size,
        "SDense": sdense_size,
        "SCoo": scoo_size,
    }


# Example: 1000-port circuit with 2% fill (sparse).
# NOTE: `sparsity` here is the fill fraction, i.e. the share of non-zero entries.
n_ports = 1000
sparsity = 0.02
# round() rather than int(): int() truncates, so floating-point error in the
# product (e.g. 100 * 0.29 == 28.999999999999996) would silently drop an entry.
n_nonzero = round(n_ports * n_ports * sparsity)

memory = estimate_memory(n_ports, n_nonzero)
print(
    f"Circuit: {n_ports} ports, {sparsity * 100:.0f}% fill ({n_nonzero:,} non-zero elements)"
)
print("\nEstimated memory usage:")
for fmt, size in memory.items():
    print(f"  {fmt}: {size / 1024:.1f} KB")
print(
    f"\nSCoo is {memory['SDense'] / memory['SCoo']:.1f}x more memory efficient than SDense"
)
Circuit: 1000 ports, 2% fill (20,000 non-zero elements)

Estimated memory usage:
  SDict: 2109.4 KB
  SDense: 7812.5 KB
  SCoo: 517.6 KB

SCoo is 15.1x more memory efficient than SDense

When to Use SCoo

Use SCoo format when:

  1. Large circuits - Circuits with hundreds or thousands of ports
  2. Sparse S-matrices - Most photonic components have sparse S-matrices (few non-zero elements)
  3. Memory constrained - When you need to fit large simulations in limited memory
  4. Performance critical - SCoo operations can be faster due to reduced data movement

Use SDict or SDense when:

  1. Interactive exploration - SDict is most readable and convenient
  2. Small circuits - The overhead of the sparse format is not worth it
  3. Dense matrices - When most S-matrix elements are non-zero
  4. Debugging - SDict makes it easy to inspect individual port pairs

Performance Example: Large MZI Chain

Let's compare performance with different return types on a larger circuit:

def coupler(coupling=0.5):
    """Ideal 2x2 directional coupler.

    ``coupling`` is the fraction of power sent to the cross port. The bar
    paths (in0->out0, in1->out1) get amplitude ``sqrt(1 - coupling)``. The
    cross paths (in0->out1, in1->out0) get ``1j * sqrt(coupling)``.
    ``sax.reciprocal`` is expected to add the reversed (out -> in) entries.
    """
    bar = (1 - coupling) ** 0.5
    cross = 1j * coupling**0.5
    return sax.reciprocal(
        {
            ("in0", "out0"): bar,
            ("in0", "out1"): cross,
            ("in1", "out0"): cross,
            ("in1", "out1"): bar,
        }
    )


def waveguide(wl=1.55, length=10.0, neff=2.4):
    """Lossless straight waveguide between ``in0`` and ``out0``.

    The transmission is the pure phase factor ``exp(1j * 2*pi*neff*length/wl)``.
    ``wl`` and ``length`` must therefore be in the same length unit.
    """
    transmission = jnp.exp(1j * (2 * jnp.pi * neff * length / wl))
    return sax.reciprocal({("in0", "out0"): transmission})


# Create an MZI chain: couplers c0..c{num_mzis} joined, stage by stage, by a
# "top" and a "btm" waveguide arm.
num_mzis = 10

netlist = {
    "instances": {
        **{f"c{i}": "coupler" for i in range(num_mzis + 1)},
        **{f"top{i}": "waveguide" for i in range(num_mzis)},
        **{f"btm{i}": "waveguide" for i in range(num_mzis)},
    },
    "connections": {
        **{f"c{i},out0": f"btm{i},in0" for i in range(num_mzis)},
        **{f"btm{i},out0": f"c{i + 1},in0" for i in range(num_mzis)},
        **{f"c{i},out1": f"top{i},in0" for i in range(num_mzis)},
        **{f"top{i},out0": f"c{i + 1},in1" for i in range(num_mzis)},
    },
    "ports": {
        "in0": "c0,in0",
        "in1": "c0,in1",
        "out0": f"c{num_mzis},out0",
        "out1": f"c{num_mzis},out1",
    },
}

models = {"coupler": coupler, "waveguide": waveguide}

# Create circuits with different return types
circuit_sdict, _ = sax.circuit(netlist, models, return_type="sdict")
circuit_sdense, _ = sax.circuit(netlist, models, return_type="sdense")
circuit_scoo, _ = sax.circuit(netlist, models, return_type="scoo")

# JIT compile all
circuit_sdict_jit = jax.jit(circuit_sdict)
circuit_sdense_jit = jax.jit(circuit_sdense)
circuit_scoo_jit = jax.jit(circuit_scoo)

# Warm up: the first call triggers compilation. JAX dispatches execution
# asynchronously, so block on each result to make sure the warm-up work has
# fully finished before any timing starts.
for warmup_fn in (circuit_sdict_jit, circuit_sdense_jit, circuit_scoo_jit):
    jax.block_until_ready(warmup_fn())

print(f"MZI chain with {num_mzis} stages created successfully")
MZI chain with 10 stages created successfully
# Time the execution (after JIT warmup).
# JAX dispatches work asynchronously: a jitted call returns as soon as the
# computation is queued, not when it has finished. Without blocking on the
# result, the loop would mostly measure dispatch overhead.
n_runs = 100


def _mean_call_ms(fn, runs):
    """Return the mean wall-clock time of ``fn()`` in milliseconds over ``runs`` calls."""
    start = time.perf_counter()
    for _ in range(runs):
        jax.block_until_ready(fn())
    return (time.perf_counter() - start) / runs * 1000


sdict_time = _mean_call_ms(circuit_sdict_jit, n_runs)
sdense_time = _mean_call_ms(circuit_sdense_jit, n_runs)
scoo_time = _mean_call_ms(circuit_scoo_jit, n_runs)

print(f"Execution time per call (average of {n_runs} runs):")
print(f"  SDict:  {sdict_time:.3f} ms")
print(f"  SDense: {sdense_time:.3f} ms")
print(f"  SCoo:   {scoo_time:.3f} ms")
Execution time per call (average of 100 runs):
  SDict:  0.171 ms
  SDense: 0.056 ms
  SCoo:   0.058 ms