Solve the sparse linear system A x = b using a precomputed numeric factorization. This is the fastest solve path: it skips both the analyze and factor steps and performs only forward/backward substitution.
Parameters
| Parameter | Type | Shape | Description |
|---|---|---|---|
| numeric | KLUHandleManager | — | Handle from factor or refactor |
| b | float64 or complex128 | (n_lhs?, n_col, n_rhs?) | Right-hand side |
| symbolic | KLUHandleManager | — | Handle from analyze |
Returns
| Type | Shape | Description |
|---|---|---|
| Array | Same shape as b | The solution x |
How It Fits In
flowchart LR
subgraph "Setup (once)"
AN["analyze"] --> SYM["symbolic"]
SYM --> FA["factor#40;Ax#41;"] --> NUM["numeric"]
end
subgraph "Hot loop (many times)"
NUM --> SWN["solve_with_numeric\n#40;numeric, b_t, symbolic#41;"]:::active
B["b_t"] --> SWN
SWN --> X["x_t"]
end
classDef active fill:#10b981,color:#fff,stroke:none
Example
import jax
import klujax
import jax.numpy as jnp
Ai = jnp.array([0, 1, 2], dtype=jnp.int32)
Aj = jnp.array([0, 1, 2], dtype=jnp.int32)
Ax = jnp.array([2.0, 3.0, 4.0])
n_col = 3
# Setup: analyze + factor (done once)
symbolic = klujax.analyze(Ai, Aj, n_col)
numeric = klujax.factor(Ai, Aj, Ax, symbolic)
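# Hot loop: only substitution (very fast)
@jax.jit
def fast_solve(b, num, sym):
    return klujax.solve_with_numeric(num, b, sym)

# Placeholder right-hand sides, one per iteration (illustrative data only)
b_values = jnp.ones((10_000, n_col))
for i in range(10_000):
    x = fast_solve(b_values[i], numeric, symbolic)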
Performance Comparison
flowchart LR
subgraph "solve"
direction LR
A1["Analyze"] --> F1["Factor"] --> S1["Solve"]
end
subgraph "solve_with_symbol"
direction LR
F2["Factor"] --> S2["Solve"]
end
subgraph "solve_with_numeric"
direction LR
S3["Solve"]
end
style A1 fill:#ef4444,color:#fff,stroke:none
style F1 fill:#ef4444,color:#fff,stroke:none
style S1 fill:#10b981,color:#fff,stroke:none
style F2 fill:#ef4444,color:#fff,stroke:none
style S2 fill:#10b981,color:#fff,stroke:none
style S3 fill:#10b981,color:#fff,stroke:none
This is the fastest option. Only forward/backward substitution is performed — no pattern analysis, no LU decomposition. Use this when the matrix A stays constant and only b changes.
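To make the cost split concrete, here is a minimal dense NumPy sketch of the same idea: factor once, then run only forward/backward substitution for each new b. It does not use klujax or KLU. The helpers `lu_factor_nopivot` and `lu_substitute` are hypothetical names chosen for illustration, and the unpivoted LU is only safe here because the demo matrix is diagonally dominant.

```python
import numpy as np

def lu_factor_nopivot(A):
    """Doolittle LU without pivoting (illustrative; assumes nonzero pivots)."""
    n = A.shape[0]
    L = np.eye(n)
    U = A.astype(float).copy()
    for k in range(n - 1):
        for i in range(k + 1, n):
            L[i, k] = U[i, k] / U[k, k]
            U[i, k:] -= L[i, k] * U[k, k:]
    return L, U

def lu_substitute(L, U, b):
    """Forward substitution (L y = b), then backward substitution (U x = y)."""
    n = b.shape[0]
    y = np.zeros(n)
    for i in range(n):
        y[i] = b[i] - L[i, :i] @ y[:i]
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

# Diagonally dominant demo matrix (hypothetical data, not from klujax).
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 5.0, 2.0],
              [0.0, 2.0, 6.0]])

# Expensive step: done once while A stays constant.
L, U = lu_factor_nopivot(A)

# Cheap step: repeated for every new right-hand side.
rng = np.random.default_rng(0)
bs = rng.normal(size=(4, 3))
xs = np.stack([lu_substitute(L, U, b) for b in bs])
```

The factorization loop costs O(n³) for a dense matrix, while each substitution costs O(n²), which is why reusing the factors pays off as soon as more than one b is solved against the same A.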
Batched Solve
If numeric was created from batched Ax (shape (n_lhs, n_nz)), you can solve all batches at once:
# numeric wraps 5 factorizations
Ax_batch = jnp.ones((5, 3)) * jnp.arange(1, 6)[:, None]
numeric = klujax.factor(Ai, Aj, Ax_batch, symbolic)
# Solve all 5 systems
b_batch = jnp.ones((5, 3))
x_batch = klujax.solve_with_numeric(numeric, b_batch, symbolic)
You can also broadcast: a single numeric (from 1D Ax) can solve against batched b:
numeric = klujax.factor(Ai, Aj, Ax, symbolic) # single factorization
b_batch = jnp.ones((10, 3, 2)) # 10 batches, 2 right-hand sides each
x_batch = klujax.solve_with_numeric(numeric, b_batch, symbolic)
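As a sanity check on the shapes above, the following dense NumPy reference reproduces both batched cases for the diagonal example matrix. It is a sketch under two assumptions: that Ai and Aj are COO row and column indices, and that b follows the (n_lhs?, n_col, n_rhs?) layout from the parameter table. It uses `np.linalg.solve` as a stand-in and does not call klujax.

```python
import numpy as np

# COO pattern of the 3x3 diagonal example (no duplicate entries).
Ai = np.array([0, 1, 2])
Aj = np.array([0, 1, 2])
n_col = 3

# Case 1: batched Ax of shape (n_lhs, n_nz), one right-hand side per system.
Ax_batch = np.ones((5, 3)) * np.arange(1, 6)[:, None]
A_batch = np.zeros((5, n_col, n_col))
A_batch[:, Ai, Aj] = Ax_batch            # scatter values into 5 dense matrices
b_batch = np.ones((5, n_col))
# Add a trailing axis so each b is an (n_col, 1) column, then drop it again.
x_ref = np.linalg.solve(A_batch, b_batch[..., None])[..., 0]

# Case 2: a single matrix broadcast against b of shape (n_lhs, n_col, n_rhs).
Ax = np.array([2.0, 3.0, 4.0])
A = np.zeros((n_col, n_col))
A[Ai, Aj] = Ax
b_bcast = np.ones((10, n_col, 2))
x_bcast = np.linalg.solve(A, b_bcast)    # A is reused for all 10 batches

print(x_ref.shape, x_bcast.shape)
```

In both cases the result has the same shape as the right-hand side, matching the Returns table.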
JAX Features
| Feature | Supported |
|---|---|
| jax.jit | Yes |
| jax.vmap | Yes |