klujax
A sparse linear solver for JAX based on the KLU algorithm
Note: This documentation was entirely generated by Claude Code.
klujax is a sparse linear solver for JAX built on top of the KLU algorithm from SuiteSparse.
It lets you solve equations of the form Ax = b, where A is a large sparse matrix (mostly zeros), and b is a known vector. The solver finds x.
Why klujax?¶
Sparse matrices show up everywhere in engineering and science: circuit simulation, finite element analysis, power grid modeling, and more. These matrices can be enormous, but most of their entries are zero. KLU is one of the fastest algorithms for solving these kinds of systems, and klujax brings it into the JAX ecosystem with full support for:
- JIT compilation via `jax.jit`
- Automatic differentiation via `jax.grad`, `jax.jacfwd`, and `jax.jacrev`
- Vectorized batching via `jax.vmap`
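The three transformations compose in the usual JAX way. As a minimal sketch that runs without SuiteSparse installed, the dense solver `jnp.linalg.solve` stands in for `klujax.solve` below; for a real sparse system you would pass the COO triplets `(Ai, Aj, Ax)` and `b` to `klujax.solve` instead:

```python
import jax
import jax.numpy as jnp

# klujax works in float64, so enable 64-bit mode
jax.config.update("jax_enable_x64", True)

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([3.0, 5.0])

# JIT compilation
solve = jax.jit(jnp.linalg.solve)
x = solve(A, b)                                   # [0.8, 1.4]

# Automatic differentiation: gradient of sum(x) with respect to b
g = jax.grad(lambda b: jnp.linalg.solve(A, b).sum())(b)

# Vectorized batching: one factorization pattern, many right-hand sides
batched = jax.vmap(jnp.linalg.solve, in_axes=(None, 0))
xs = batched(A, jnp.stack([b, 2 * b]))            # shape (2, 2)
```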
Quick Example¶
```python
import jax
import jax.numpy as jnp
import klujax

# klujax solves in float64, so enable 64-bit mode in JAX
jax.config.update("jax_enable_x64", True)

# A 5x5 sparse matrix in COO format
Ai = jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4])  # row indices
Aj = jnp.array([0, 1, 0, 2, 4, 1, 2, 3, 2, 1, 2, 4])  # col indices
Ax = jnp.array([2, 3, 3, 4, 6, -1, -3, 2, 1, 4, 2, 1], dtype=jnp.float64)  # values
b = jnp.array([8.0, 45.0, -3.0, 3.0, 19.0])

x = klujax.solve(Ai, Aj, Ax, b)
print(x)  # [1. 2. 3. 4. 5.]
```
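You can cross-check the result without klujax by scattering the same COO triplets into a dense array and solving with NumPy. This is a sketch for verification only; it defeats the point of a sparse solver for large matrices:

```python
import numpy as np

Ai = np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4])
Aj = np.array([0, 1, 0, 2, 4, 1, 2, 3, 2, 1, 2, 4])
Ax = np.array([2, 3, 3, 4, 6, -1, -3, 2, 1, 4, 2, 1], dtype=np.float64)
b = np.array([8.0, 45.0, -3.0, 3.0, 19.0])

# Scatter the COO triplets into a dense 5x5 array (no duplicate indices here)
A = np.zeros((5, 5))
A[Ai, Aj] = Ax

x = np.linalg.solve(A, b)
print(np.allclose(x, [1.0, 2.0, 3.0, 4.0, 5.0]))  # → True
```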
At a Glance¶
```mermaid
flowchart LR
    A["Sparse Matrix A\n#40;Ai, Aj, Ax#41;"] --> solve["klujax.solve"]
    B["Right-hand side b"] --> solve
    solve --> X["Solution x"]
    style solve fill:#6366f1,color:#fff,stroke:none
```
Constraints¶
- CPU only: KLU is a CPU algorithm; there is no GPU support.
- float64 / complex128 only: lower-precision inputs are automatically upcast.
- Sparse matrices must be in COO (coordinate) format.
- Duplicate indices must be coalesced before solving (use `klujax.coalesce`).
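Coalescing means summing the values of entries that share the same (row, col) pair. `klujax.coalesce` does this for you; the operation it performs can be sketched in plain NumPy (the variable names below are illustrative, not klujax's API):

```python
import numpy as np

# Two entries at (0, 0): their values 1.0 and 2.0 must be summed to 3.0
Ai = np.array([0, 0, 1])
Aj = np.array([0, 0, 1])
Ax = np.array([1.0, 2.0, 5.0])

# Group triplets by their (row, col) key
keys = np.stack([Ai, Aj], axis=1)
uniq, inv = np.unique(keys, axis=0, return_inverse=True)

# np.add.at accumulates values that map to the same unique key
coalesced = np.zeros(len(uniq))
np.add.at(coalesced, inv, Ax)

Ai_c, Aj_c = uniq[:, 0], uniq[:, 1]  # coalesced indices
print(coalesced)  # → [3. 5.]
```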
What's Next?¶
- Getting Started — install and run your first solve
- Core Concepts — understand sparse matrices, COO format, and what KLU does
- API Reference — detailed docs for every function
- Examples — practical code you can copy-paste