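Recompute the LU factorization with new matrix values, reusing both the symbolic analysis and the memory of the existing numeric handle. This is faster than calling `factor` again because the factorization is updated in place instead of being allocated anew. The sparsity pattern (`Ai`, `Aj`) must be the same one used for `analyze` and `factor`.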
## Parameters
| Parameter | Type | Shape | Description |
|---|---|---|---|
| `Ai` | `int32` | `(n_nz,)` | Row indices |
| `Aj` | `int32` | `(n_nz,)` | Column indices |
| `Ax` | `float64` or `complex128` | `(n_lhs?, n_nz)` | New matrix values |
| `numeric` | `KLUHandleManager` | n/a | Existing handle from `factor` or a previous `refactor` |
| `symbolic` | `KLUHandleManager` | n/a | Handle from `analyze` |
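To make the COO layout concrete, the sketch below scatters `Ai`, `Aj`, and `Ax` into a dense matrix using plain `jax.numpy`. It reads the `?` in `(n_lhs?, n_nz)` as an optional leading batch dimension (an assumption), where every matrix in the batch shares one sparsity pattern. The helper `coo_to_dense` is hypothetical and not part of klujax; it sums duplicate `(row, col)` entries, the usual COO convention.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the table above lists float64/complex128 values


def coo_to_dense(Ai, Aj, Ax, n_col):
    """Hypothetical helper: scatter COO entries into a dense (n_lhs?, n_col, n_col) array.

    Duplicate (row, col) pairs are summed by .at[].add.
    """
    if Ax.ndim == 1:  # single matrix: Ax has shape (n_nz,)
        return jnp.zeros((n_col, n_col), dtype=Ax.dtype).at[Ai, Aj].add(Ax)
    # Batched: Ax has shape (n_lhs, n_nz); all matrices share the same Ai, Aj
    n_lhs = Ax.shape[0]
    return jnp.zeros((n_lhs, n_col, n_col), dtype=Ax.dtype).at[:, Ai, Aj].add(Ax)


Ai = jnp.array([0, 1, 2, 0], dtype=jnp.int32)  # row indices
Aj = jnp.array([0, 1, 2, 2], dtype=jnp.int32)  # column indices
Ax = jnp.array([4.0, 5.0, 6.0, 1.0])           # one value per (Ai[k], Aj[k])
Ax_batch = jnp.stack([Ax, 2.0 * Ax])           # shape (2, n_nz), i.e. n_lhs = 2

A = coo_to_dense(Ai, Aj, Ax, n_col=3)
A_batch = coo_to_dense(Ai, Aj, Ax_batch, n_col=3)
print(A)
print(A_batch.shape)  # (2, 3, 3)
```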
## Returns

| Type | Description |
|---|---|
| `KLUHandleManager` | The same numeric handle, updated in place with the new factorization |
## How It Fits In

```mermaid
flowchart TD
    AN["analyze"] --> SYM["symbolic"]
    SYM --> FA["factor#40;Ax₀#41;"] --> NUM["numeric"]
    NUM --> RF["refactor#40;Ax₁, numeric, symbolic#41;"]:::active
    RF --> NUM2["numeric #40;updated#41;"]
    NUM2 --> RF2["refactor#40;Ax₂, numeric, symbolic#41;"]:::active
    RF2 --> NUM3["numeric #40;updated#41;"]
    NUM3 --> SWN["solve_with_numeric"]
    classDef active fill:#8b5cf6,color:#fff,stroke:none
```
## Example: Simulation Loop
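```python
import jax
import jax.numpy as jnp
import klujax

jax.config.update("jax_enable_x64", True)  # Ax must be float64 or complex128

# Sparsity pattern (COO) of a 3x3 diagonal matrix, fixed for the whole simulation
Ai = jnp.array([0, 1, 2], dtype=jnp.int32)
Aj = jnp.array([0, 1, 2], dtype=jnp.int32)
n_col = 3
num_steps = 10  # illustrative

def compute_new_matrix_values(t):
    # Hypothetical helper: values at step t, one per (Ai[k], Aj[k]) entry
    return jnp.array([2.0, 3.0, 4.0], dtype=jnp.float64) + t

b_t = jnp.ones(n_col, dtype=jnp.float64)  # right-hand side, kept fixed for simplicity

# Setup: symbolic analysis once, then the first numeric factorization
symbolic = klujax.analyze(Ai, Aj, n_col)
Ax_initial = compute_new_matrix_values(0)
numeric = klujax.factor(Ai, Aj, Ax_initial, symbolic)

# Simulation loop: matrix values change each step, the sparsity pattern does not
for t in range(1, num_steps):
    Ax_t = compute_new_matrix_values(t)
    numeric = klujax.refactor(Ai, Aj, Ax_t, numeric, symbolic)
    x_t = klujax.solve_with_numeric(numeric, b_t, symbolic)
```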
## `factor` vs. `refactor`
| | `factor` | `refactor` |
|---|---|---|
| Needs an existing `numeric` handle? | No | Yes |
| Allocates new memory? | Yes | No (updates in place) |
| Speed | Baseline | Slightly faster |
| Use case | First factorization | Subsequent value updates with the same sparsity pattern |
`refactor` is most useful in time-stepping simulations, where the matrix values change at every step but the sparsity pattern (`Ai`, `Aj`) stays the same.
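As an illustration of such a system, the sketch below builds the COO arrays for an implicit-Euler step of a 1D diffusion problem with a time-varying coefficient: `Ai` and `Aj` are built once, and only the values change per step. The diffusion problem, the coefficient `c(t)`, and the helpers `tridiagonal_pattern` and `implicit_euler_values` are assumptions for illustration. A dense `jnp.linalg.solve` stands in where `klujax.refactor` followed by `klujax.solve_with_numeric` would go, so the sketch runs without klujax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n = 5      # number of grid points (illustrative)
dt = 0.1   # time step (illustrative)


def tridiagonal_pattern(n):
    """Fixed COO pattern of a tridiagonal n x n matrix: built once, reused every step."""
    rows, cols = [], []
    for i in range(n):
        for j in (i - 1, i, i + 1):
            if 0 <= j < n:
                rows.append(i)
                cols.append(j)
    return jnp.array(rows, dtype=jnp.int32), jnp.array(cols, dtype=jnp.int32)


Ai, Aj = tridiagonal_pattern(n)


def implicit_euler_values(t):
    """Values of (I + dt * c(t) * L) on the fixed pattern, with L the 1D Laplacian [-1, 2, -1]."""
    c = 1.0 + 0.5 * jnp.sin(t)        # hypothetical time-varying diffusion coefficient
    diag = Ai == Aj
    lap = jnp.where(diag, 2.0, -1.0)  # Laplacian entry for each (Ai[k], Aj[k])
    return jnp.where(diag, 1.0, 0.0) + dt * c * lap


u = jnp.ones(n)
for step in range(3):
    Ax_t = implicit_euler_values((step + 1) * dt)  # new values, same Ai/Aj
    # Dense stand-in for:
    #   numeric = klujax.refactor(Ai, Aj, Ax_t, numeric, symbolic)
    #   u = klujax.solve_with_numeric(numeric, u, symbolic)
    A_t = jnp.zeros((n, n)).at[Ai, Aj].add(Ax_t)
    u = jnp.linalg.solve(A_t, u)
print(u)
```

Because only `Ax_t` changes inside the loop, the symbolic analysis and the numeric handle's memory can be reused at every step, which is the situation `refactor` targets.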
## JAX Features

| Feature | Supported |
|---|---|
| `jax.jit` | Yes (via internal JIT wrapper) |
| `jax.grad` | Yes (w.r.t. `Ax`) |
| `jax.vmap` | Yes |
| `jax.pmap` | Yes |
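Since gradients w.r.t. `Ax` are supported, a dense reference can be handy for spot-checking them. The sketch below uses only `jax.numpy` and a hypothetical scalar `loss`. It compares `jax.grad` through a dense solve with the standard adjoint identity `dL/dAx[k] = -lam[Ai[k]] * x[Aj[k]]`, where `A.T @ lam = dL/dx`. It is a reference check under these assumptions, not a description of how klujax computes its gradients internally.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

n_col = 3
Ai = jnp.array([0, 1, 2, 0, 2], dtype=jnp.int32)
Aj = jnp.array([0, 1, 2, 2, 1], dtype=jnp.int32)
Ax = jnp.array([4.0, 5.0, 6.0, 1.0, 2.0])
b = jnp.array([1.0, 2.0, 3.0])


def dense_solve(Ax, b):
    """Dense reference for x = A^{-1} b, with A assembled from the COO triplets."""
    A = jnp.zeros((n_col, n_col), dtype=Ax.dtype).at[Ai, Aj].add(Ax)
    return jnp.linalg.solve(A, b)


def loss(Ax):
    # Hypothetical scalar objective of the solution
    return jnp.sum(dense_solve(Ax, b) ** 2)


grad_autodiff = jax.grad(loss)(Ax)

# Adjoint identity: dL/dAx[k] = -lam[Ai[k]] * x[Aj[k]], where A^T lam = dL/dx
A = jnp.zeros((n_col, n_col)).at[Ai, Aj].add(Ax)
x = jnp.linalg.solve(A, b)
lam = jnp.linalg.solve(A.T, 2.0 * x)  # dL/dx = 2x for this loss
grad_adjoint = -lam[Ai] * x[Aj]

print(bool(jnp.allclose(grad_autodiff, grad_adjoint)))  # True
```

Swapping `dense_solve` for a klujax-based solve with the same `Ai`, `Aj`, and `Ax` would give a like-for-like comparison, assuming the sparse path accepts the same inputs.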