Sparse COO
SAX supports three S-parameter representations:
`SDict` (dictionary), `SDense` (dense array), and `SCoo` (sparse COO format). The COO format is the most memory-efficient choice for large, sparse circuits, but it requires more careful handling. This notebook explains when and how to use the `SCoo` format.
S-Parameter Representations
SAX provides three ways to represent S-parameters:
| Format | Type | Best For | Memory | Access |
|---|---|---|---|---|
| `SDict` | `Dict[Tuple[str, str], Array]` | Interactive use, debugging | Medium | O(1) by port name |
| `SDense` | `Tuple[Array, Dict]` | Dense matrices, small circuits | High | O(1) by index |
| `SCoo` | `Tuple[Array, Array, Array, Dict]` | Large sparse circuits | Low | Sparse iteration |
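To make the table concrete, here is a hand-written sketch of one hypothetical 2-port component (a lossless waveguide with transmission `t`) in each of the three layouts. It uses NumPy so it runs without SAX or JAX, and it assumes that an SDict key `(p1, p2)` corresponds to the matrix entry `[port_map[p1], port_map[p2]]` (for a reciprocal component like this one, both orientations hold the same value).

```python
import numpy as np

# Hypothetical 2-port component: a lossless waveguide with transmission t
t = np.exp(1j * 0.3)

# SDict: one entry per (port, port) pair; pairs not listed are zero
sdict = {("in0", "out0"): t, ("out0", "in0"): t}

# SDense: the full matrix plus a port map (port name -> row/column index)
port_map = {"in0": 0, "out0": 1}
sdense = (np.array([[0, t], [t, 0]], dtype=complex), port_map)

# SCoo: row indices, column indices, values, and the same port map
scoo = (np.array([0, 1]), np.array([1, 0]), np.array([t, t]), port_map)

# Look up the in0 -> out0 entry in each format
print(sdict[("in0", "out0")])
S, pm = sdense
print(S[pm["in0"], pm["out0"]])
i, j, data, pm = scoo
print(data[(i == pm["in0"]) & (j == pm["out0"])][0])
```

The SDict lookup is by name, the SDense lookup goes through the port map, and the SCoo lookup has to search the index arrays, which is why the table lists its access pattern as sparse iteration.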
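The COO (coordinate) format stores only the entries you list, typically the non-zero ones, as three parallel arrays:
- `i`: row indices
- `j`: column indices
- `data`: S-parameter values

A port map dictionary completes the tuple by mapping each port name to its row/column index. Any (i, j) position that is not listed is implicitly zero.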
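Rebuilding the dense matrix from these pieces is a single scatter operation. The sketch below is hand-rolled with NumPy and is not SAX's own implementation (`sax.sdense` does this conversion for you); it assumes every (i, j) pair appears at most once.

```python
import numpy as np


def coo_to_dense(i, j, data, port_map):
    """Scatter COO entries into a dense S-matrix (hand-rolled sketch)."""
    n = len(port_map)
    S = np.zeros((n, n), dtype=complex)
    # Assumes each (i, j) pair appears once; a repeated pair would be
    # overwritten here rather than summed.
    S[i, j] = data
    return S, port_map


# Hypothetical 3-port example: port a couples to b and c with amplitude 1/sqrt(2)
pm = {"a": 0, "b": 1, "c": 2}
amp = 1 / np.sqrt(2)
i = np.array([0, 0, 1, 2])
j = np.array([1, 2, 0, 0])
data = np.array([amp, amp, amp, amp])

S, _ = coo_to_dense(i, j, data, pm)
print(S.round(3))
```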
Creating SCoo Models
Here's how to create a model that returns the SCoo format directly. This avoids building a dense matrix for components whose S-matrix is mostly zeros, which matters most for large circuits:
```python
import jax.numpy as jnp


def my_coo():
    """A 5-port splitter in SCoo format (4 inputs, 1 output)."""
    num_input_ports = 4
    num_output_ports = 1

    # Port map: maps port names to matrix indices
    pm = {
        **{f"in{i}": i for i in range(num_input_ports)},
        **{f"out{i}": i + num_input_ports for i in range(num_output_ports)},
    }

    # Non-zero transmission values (equal coupling: 4 * 0.5**2 = 1)
    thru = jnp.ones(num_input_ports) * 0.5

    # COO indices: each input connects to the single output
    i = jnp.arange(0, num_input_ports, 1)
    j = jnp.zeros_like(i) + num_input_ports

    # Make reciprocal by duplicating with swapped indices
    i, j = jnp.concatenate([i, j]), jnp.concatenate([j, i])
    thru = jnp.concatenate([thru, thru], 0)

    return (i, j, thru, pm)


# Let's see what this returns
result = my_coo()
print("Row indices (i):", result[0])
print("Col indices (j):", result[1])
print("Values:", result[2])
print("Port map:", result[3])
```
```
Row indices (i): [0 1 2 3 4 4 4 4]
Col indices (j): [4 4 4 4 0 1 2 3]
Values: [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]
Port map: {'in0': 0, 'in1': 1, 'in2': 2, 'in3': 3, 'out0': 4}
```
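Nothing forces the three arrays and the port map to agree, so it can help to sanity-check a hand-built SCoo tuple before using it in a circuit. `check_scoo` below is a hypothetical helper, not part of SAX. It assumes any batch axes of `data` (for example a wavelength sweep) come before the entry axis, and it treats repeated (i, j) pairs as an error because whether they overwrite or add depends on the code that consumes them.

```python
import numpy as np


def check_scoo(scoo):
    """Raise AssertionError if an (i, j, data, port_map) tuple is inconsistent."""
    i, j, data, pm = scoo
    n = len(pm)
    # One row index, one column index and one value per stored entry
    assert i.ndim == 1 and i.shape == j.shape, "i and j must be 1-D and equal length"
    assert data.shape[-1] == i.shape[0], "data's last axis must match len(i)"
    # Port map indices must be exactly 0..n-1
    assert sorted(pm.values()) == list(range(n)), "port map indices must be 0..n-1"
    # Every stored entry must point inside the n x n matrix
    assert ((0 <= i) & (i < n)).all() and ((0 <= j) & (j < n)).all(), "index out of range"
    # No (i, j) position should be stored twice
    assert len(set(zip(i.tolist(), j.tolist()))) == i.shape[0], "duplicate (i, j) pairs"
    return True


# The same arrays that my_coo() printed above, rebuilt with NumPy
i = np.array([0, 1, 2, 3, 4, 4, 4, 4])
j = np.array([4, 4, 4, 4, 0, 1, 2, 3])
data = np.full(8, 0.5)
pm = {"in0": 0, "in1": 1, "in2": 2, "in3": 3, "out0": 4}
print(check_scoo((i, j, data, pm)))
```

JAX arrays also provide `.ndim`, `.shape`, `.tolist()` and `.all()`, so the same checks should work on `my_coo()` directly, outside of `jax.jit`.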
Using SCoo in Circuits
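You can use SCoo models in circuits just like any other model. Pass `return_type="scoo"` to get the circuit result in SCoo format; this example also selects the `klu` backend. The result is not guaranteed to stay sparse: in the output below, all 25 entries of the 5x5 matrix are stored, zeros included.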
```python
import sax

circuit, _ = sax.circuit(
    netlist={
        "instances": {
            "coo": "coo",
        },
        "connections": {},
        "ports": {
            "in0": "coo,in0",
            "in1": "coo,in1",
            "in2": "coo,in2",
            "in3": "coo,in3",
            "out0": "coo,out0",
        },
    },
    models={
        "coo": my_coo,
    },
    backend="klu",
    return_type="scoo",
)
scoo_result = circuit()
print("SCoo result:", scoo_result)
```
```
SCoo result: (Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4], dtype=int64), Array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1,
       2, 3, 4], dtype=int64), Array([0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0.5+0.j, 0. +0.j, 0. +0.j,
       0. +0.j, 0. +0.j, 0.5+0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j,
       0.5+0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0. +0.j, 0.5+0.j, 0.5+0.j,
       0.5+0.j, 0.5+0.j, 0.5+0.j, 0. +0.j], dtype=complex128), {'in0': 0, 'in1': 1, 'in2': 2, 'in3': 3, 'out0': 4})
```
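Because the circuit output above stores explicit zeros, you may want to drop them before saving or inspecting the result. The sketch below does this with NumPy outside of `jax.jit`: a boolean mask yields an array whose length depends on the data, which `jax.jit` cannot trace. `prune_zeros` is a hypothetical helper and assumes `data` is 1-D (no wavelength batch axis).

```python
import numpy as np


def prune_zeros(scoo, tol=0.0):
    """Return a copy of an SCoo tuple without entries whose magnitude is <= tol."""
    i, j, data, pm = scoo
    i, j, data = np.asarray(i), np.asarray(j), np.asarray(data)
    keep = np.abs(data) > tol  # boolean mask over the stored entries
    return i[keep], j[keep], data[keep], pm


# Small stand-in for a "dense" SCoo result: a 2x2 matrix with a zero diagonal
dense_like = (
    np.array([0, 0, 1, 1]),
    np.array([0, 1, 0, 1]),
    np.array([0.0, 0.5, 0.5, 0.0], dtype=complex),
    {"a": 0, "b": 1},
)
i, j, data, pm = prune_zeros(dense_like)
print(i, j, data)
```

Applied to `scoo_result`, this would keep the 8 entries whose value is 0.5.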
Converting Between Formats
Use `sax.sdict` and `sax.sdense` to convert between formats:
```python
# Convert SCoo to SDict for easier inspection
sdict_result = sax.sdict(scoo_result)
print("As SDict:")
for (p1, p2), val in sdict_result.items():
    print(f" {p1} -> {p2}: {val}")

# Convert to SDense (dense matrix)
sdense_result = sax.sdense(scoo_result)
print("\nAs SDense (matrix, port_map):")
print("Matrix shape:", sdense_result[0].shape)
print("Port map:", sdense_result[1])
```
```
As SDict:
 in0 -> in0: 0j
 in0 -> in1: 0j
 in0 -> in2: 0j
 in0 -> in3: 0j
 in0 -> out0: (0.5+0j)
 in1 -> in0: 0j
 in1 -> in1: 0j
 in1 -> in2: 0j
 in1 -> in3: 0j
 in1 -> out0: (0.5+0j)
 in2 -> in0: 0j
 in2 -> in1: 0j
 in2 -> in2: 0j
 in2 -> in3: 0j
 in2 -> out0: (0.5+0j)
 in3 -> in0: 0j
 in3 -> in1: 0j
 in3 -> in2: 0j
 in3 -> in3: 0j
 in3 -> out0: (0.5+0j)
 out0 -> in0: (0.5+0j)
 out0 -> in1: (0.5+0j)
 out0 -> in2: (0.5+0j)
 out0 -> in3: (0.5+0j)
 out0 -> out0: 0j

As SDense (matrix, port_map):
Matrix shape: (5, 5)
Port map: {'in0': 0, 'in1': 1, 'in2': 2, 'in3': 3, 'out0': 4}
```
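Going the other way, from an SDict to an SCoo tuple, mostly comes down to choosing a port ordering. The hand-rolled NumPy sketch below is not how SAX implements the conversion: it handles scalar values only, and its sorted port order may differ from the order SAX chooses.

```python
import numpy as np


def sdict_to_scoo(sdict):
    """Build an (i, j, data, port_map) tuple from an SDict with scalar values."""
    # Every port name that appears in any key, in sorted order
    ports = sorted({p for pair in sdict for p in pair})
    pm = {p: k for k, p in enumerate(ports)}
    keys = list(sdict)  # dicts keep insertion order
    i = np.array([pm[p1] for p1, _ in keys])
    j = np.array([pm[p2] for _, p2 in keys])
    data = np.array([sdict[k] for k in keys], dtype=complex)
    return i, j, data, pm


sdict = {
    ("in0", "out0"): 0.5,
    ("out0", "in0"): 0.5,
    ("in1", "out0"): 0.5,
    ("out0", "in1"): 0.5,
}
i, j, data, pm = sdict_to_scoo(sdict)
print(pm)
print(i, j, data)
```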
Memory Efficiency Comparison
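For large sparse circuits, SCoo can use significantly less memory. The back-of-envelope estimate below compares the three formats; its per-entry overheads are rough assumptions, not measurements: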
```python
def estimate_memory(n_ports, n_nonzero, dtype_size=8):
    """Rough memory estimate (in bytes) for each format.

    dtype_size is used for both values and indices; ~50 bytes per
    port name is a guess at Python object overhead.
    """
    # SDict: n_nonzero entries, each with a key of two port names + one value
    sdict_overhead = n_nonzero * (2 * 50 + dtype_size)
    # SDense: full n x n matrix (port map ignored)
    sdense_size = n_ports * n_ports * dtype_size
    # SCoo: 3 arrays of length n_nonzero + port map
    scoo_size = 3 * n_nonzero * dtype_size + n_ports * 50
    return {
        "SDict": sdict_overhead,
        "SDense": sdense_size,
        "SCoo": scoo_size,
    }


# Example: 1000-port circuit with 2% fill (sparse)
n_ports = 1000
sparsity = 0.02
n_nonzero = int(n_ports * n_ports * sparsity)
memory = estimate_memory(n_ports, n_nonzero)

print(
    f"Circuit: {n_ports} ports, {sparsity * 100:.0f}% fill ({n_nonzero:,} non-zero elements)"
)
print("\nEstimated memory usage:")
for fmt, size in memory.items():
    print(f"  {fmt}: {size / 1024:.1f} KB")
print(
    f"\nSCoo is {memory['SDense'] / memory['SCoo']:.1f}x more memory efficient than SDense"
)
```
```
Circuit: 1000 ports, 2% fill (20,000 non-zero elements)

Estimated memory usage:
  SDict: 2109.4 KB
  SDense: 7812.5 KB
  SCoo: 517.6 KB

SCoo is 15.1x more memory efficient than SDense
```
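The estimate above hard-codes per-entry overheads. For the two array-based formats you can instead measure actual buffer sizes with NumPy's `nbytes`. This sketch assumes int32 indices and complex64 values (JAX's defaults when 64-bit mode is off); with int64/complex128, as in the circuit output earlier, both totals double. It ignores the port map and Python object overhead, and the random sparsity pattern is purely illustrative.

```python
import numpy as np

n_ports = 1000
fill = 0.02
n_nonzero = round(n_ports * n_ports * fill)

rng = np.random.default_rng(0)
# Pick n_nonzero distinct flat positions, then split them into row/column indices
flat = rng.choice(n_ports * n_ports, size=n_nonzero, replace=False)
rows, cols = np.divmod(flat, n_ports)
i = rows.astype(np.int32)
j = cols.astype(np.int32)
data = (rng.standard_normal(n_nonzero) + 1j * rng.standard_normal(n_nonzero)).astype(
    np.complex64
)

# Dense equivalent of the same matrix
dense = np.zeros((n_ports, n_ports), dtype=np.complex64)
dense[i, j] = data

dense_bytes = dense.nbytes
coo_bytes = i.nbytes + j.nbytes + data.nbytes
print(f"dense: {dense_bytes / 1024:.1f} KB, COO: {coo_bytes / 1024:.1f} KB")
```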
When to Use SCoo
Use SCoo format when:
- Large circuits - Circuits with hundreds or thousands of ports
- Sparse S-matrices - Most photonic components have sparse S-matrices (few non-zero elements)
- Memory constrained - When you need to fit large simulations in limited memory
- Performance critical - SCoo operations can be faster on large, sparse problems because less data is moved; on small circuits the gain may be negligible
Use SDict or SDense when:
- Interactive exploration - SDict is most readable and convenient
- Small circuits - The overhead of the sparse format is not worth it
- Dense matrices - When most S-matrix elements are non-zero
- Debugging - SDict makes it easy to inspect individual port pairs
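Memory alone gives a simple break-even point. A dense matrix stores n² values, while COO stores nnz × (two indices + one value), so COO is smaller whenever the fill fraction nnz / n² is below value_bytes / (2 × index_bytes + value_bytes). The sketch below evaluates that ratio for a few dtype combinations; it ignores the port map and says nothing about speed, which depends on the backend.

```python
def break_even_fill(index_bytes: int, value_bytes: int) -> float:
    """Fill fraction below which COO needs fewer bytes than a dense matrix.

    dense bytes = n**2 * value_bytes
    COO bytes   = nnz * (2 * index_bytes + value_bytes)
    """
    return value_bytes / (2 * index_bytes + value_bytes)


for idx, val, label in [
    (4, 8, "int32 indices, complex64 values"),
    (8, 16, "int64 indices, complex128 values"),
    (4, 16, "int32 indices, complex128 values"),
]:
    print(f"{label}: COO is smaller below {break_even_fill(idx, val):.0%} fill")
```

At the 2% fill used in the memory example, COO sits far below every one of these thresholds.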
Performance Example: Large MZI Chain
Let's compare performance of the three return types on a larger circuit: a 10-stage MZI chain built from simple coupler and waveguide models.
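```python
import jax
import jax.numpy as jnp
import sax


def coupler(coupling=0.5):
    # 2x2 directional coupler; `coupling` is the power fraction sent to the cross port
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return sax.reciprocal(
        {
            ("in0", "out0"): tau,
            ("in0", "out1"): 1j * kappa,
            ("in1", "out0"): 1j * kappa,
            ("in1", "out1"): tau,
        }
    )


def waveguide(wl=1.55, length=10.0, neff=2.4):
    # Lossless waveguide: phase = 2*pi*neff*length/wl (wl and length in the same units)
    phase = 2 * jnp.pi * neff * length / wl
    return sax.reciprocal({("in0", "out0"): jnp.exp(1j * phase)})
```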
```python
# Create an MZI chain
num_mzis = 10
netlist = {
    "instances": {
        **{f"c{i}": "coupler" for i in range(num_mzis + 1)},
        **{f"top{i}": "waveguide" for i in range(num_mzis)},
        **{f"btm{i}": "waveguide" for i in range(num_mzis)},
    },
    "connections": {
        **{f"c{i},out0": f"btm{i},in0" for i in range(num_mzis)},
        **{f"btm{i},out0": f"c{i + 1},in0" for i in range(num_mzis)},
        **{f"c{i},out1": f"top{i},in0" for i in range(num_mzis)},
        **{f"top{i},out0": f"c{i + 1},in1" for i in range(num_mzis)},
    },
    "ports": {
        "in0": "c0,in0",
        "in1": "c0,in1",
        "out0": f"c{num_mzis},out0",
        "out1": f"c{num_mzis},out1",
    },
}
models = {"coupler": coupler, "waveguide": waveguide}

# Create circuits with different return types
circuit_sdict, _ = sax.circuit(netlist, models, return_type="sdict")
circuit_sdense, _ = sax.circuit(netlist, models, return_type="sdense")
circuit_scoo, _ = sax.circuit(netlist, models, return_type="scoo")
```
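With hand-written netlists like this one, a common mistake is wiring the same instance port twice. The hypothetical helper below (not a SAX function) lists any `"instance,port"` reference that appears more than once across `connections` and `ports`; it assumes the comma-separated naming used in this notebook.

```python
from collections import Counter


def ports_used_more_than_once(netlist):
    """Return "instance,port" strings referenced more than once in a netlist."""
    refs = []
    for a, b in netlist["connections"].items():
        refs += [a, b]
    refs += list(netlist["ports"].values())
    return [ref for ref, count in Counter(refs).items() if count > 1]


# A correct two-instance netlist and a broken variant that reuses "wg,out0"
good = {
    "connections": {"c0,out0": "wg,in0"},
    "ports": {"in0": "c0,in0", "out0": "wg,out0"},
}
bad = {
    "connections": {"c0,out0": "wg,in0", "c0,out1": "wg,out0"},
    "ports": {"in0": "c0,in0", "out0": "wg,out0"},
}
print(ports_used_more_than_once(good))
print(ports_used_more_than_once(bad))
```

Called on the MZI `netlist` above, it returns an empty list: every coupler and waveguide port is referenced exactly once.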
```python
# JIT compile all
circuit_sdict_jit = jax.jit(circuit_sdict)
circuit_sdense_jit = jax.jit(circuit_sdense)
circuit_scoo_jit = jax.jit(circuit_scoo)

# Warm up JIT
_ = circuit_sdict_jit()
_ = circuit_sdense_jit()
_ = circuit_scoo_jit()

print(f"MZI chain with {num_mzis} stages created successfully")
```

```
MZI chain with 10 stages created successfully
```
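Because none of these models reflects light, the chain is purely feed-forward, so its input-to-output block can be cross-checked by multiplying 2x2 matrices. The NumPy sketch below rebuilds that block from the same coupler and waveguide formulas with default parameters, so the top and bottom arms are identical. Under these assumptions every input-to-output power comes out as 0.5 and the block is unitary. `coupler_block`, `waveguide_phase` and `mzi_chain_transfer` are hypothetical names used only here.

```python
import numpy as np


def coupler_block(coupling=0.5):
    """In->out block of the coupler model: [out0, out1] = M @ [in0, in1]."""
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return np.array([[tau, 1j * kappa], [1j * kappa, tau]])


def waveguide_phase(wl=1.55, length=10.0, neff=2.4):
    """Transmission factor of the waveguide model."""
    return np.exp(2j * np.pi * neff * length / wl)


def mzi_chain_transfer(num_mzis=10):
    """Forward transfer matrix of the chain (the models have no reflections)."""
    C = coupler_block()
    # Bottom arm carries coupler output out0, top arm carries out1
    P = np.diag([waveguide_phase(), waveguide_phase()])
    T = C  # first coupler, c0
    for _ in range(num_mzis):
        T = C @ P @ T  # arms of one stage, then the next coupler
    return T  # T[k, m]: amplitude from input port m to output port k


T = mzi_chain_transfer(10)
print(np.round(np.abs(T) ** 2, 6))
print(np.allclose(T.conj().T @ T, np.eye(2)))
```

If the assumptions hold, `abs(circuit_sdict_jit()["in0", "out0"]) ** 2` should agree with `abs(T[0, 0]) ** 2` up to floating-point error.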
```python
import time

# Time the execution (after JIT warmup)
n_runs = 100

start = time.perf_counter()
for _ in range(n_runs):
    _ = circuit_sdict_jit()
sdict_time = (time.perf_counter() - start) / n_runs * 1000

start = time.perf_counter()
for _ in range(n_runs):
    _ = circuit_sdense_jit()
sdense_time = (time.perf_counter() - start) / n_runs * 1000

start = time.perf_counter()
for _ in range(n_runs):
    _ = circuit_scoo_jit()
scoo_time = (time.perf_counter() - start) / n_runs * 1000

print(f"Execution time per call (average of {n_runs} runs):")
print(f"  SDict: {sdict_time:.3f} ms")
print(f"  SDense: {sdense_time:.3f} ms")
print(f"  SCoo: {scoo_time:.3f} ms")
```
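```
Execution time per call (average of 100 runs):
  SDict: 0.171 ms
  SDense: 0.056 ms
  SCoo: 0.058 ms
```

In this run, SDense and SCoo were both about 3x faster than SDict, and SCoo was not faster than SDense. With only four external ports, this circuit is likely too small for sparsity to pay off. Timings vary with hardware, backend, and JAX version.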
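JAX dispatches work asynchronously, so a timing loop that never waits for results can measure dispatch time rather than compute time. A minimal sketch of a blocking timer follows. It walks the nested dicts, lists and tuples that make up SDict, SDense and SCoo outputs and calls `block_until_ready()` on any leaf that has that method (JAX arrays do). Recent JAX versions also provide `jax.block_until_ready` for whole pytrees. `block` and `time_per_call_ms` are hypothetical helpers, not part of SAX.

```python
import time
from typing import Any, Callable


def block(x: Any) -> Any:
    """Wait on every nested leaf that has a block_until_ready() method."""
    if isinstance(x, dict):
        for v in x.values():
            block(v)
    elif isinstance(x, (list, tuple)):
        for v in x:
            block(v)
    elif hasattr(x, "block_until_ready"):
        x.block_until_ready()
    return x


def time_per_call_ms(fn: Callable[[], Any], n_runs: int = 100) -> float:
    """Average wall-clock time per call in milliseconds, waiting for each result."""
    block(fn())  # warm-up call (triggers compilation for jitted functions)
    start = time.perf_counter()
    for _ in range(n_runs):
        block(fn())
    return (time.perf_counter() - start) / n_runs * 1000
```

For example, `time_per_call_ms(circuit_scoo_jit)` times the SCoo circuit. Waiting after every call measures per-call latency, which can be higher than the throughput a pipelined loop achieves.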