Solve multiple linear systems at once

klujax can solve many linear systems simultaneously. This is useful when you have multiple matrices sharing the same sparsity pattern, or multiple right-hand sides for the same matrix.

Multiple Right-Hand Sides

Solve Ax₁ = b₁ and Ax₂ = b₂ with the same A:

import klujax
import jax.numpy as jnp

Ai = jnp.array([0, 1, 2], dtype=jnp.int32)
Aj = jnp.array([0, 1, 2], dtype=jnp.int32)
Ax = jnp.array([2.0, 3.0, 4.0])

# b has shape (3, 2) — 3 rows, 2 right-hand sides
b = jnp.array([
    [6.0, 10.0],
    [9.0, 15.0],
    [12.0, 20.0],
])

x = klujax.solve(Ai, Aj, Ax, b)
# x[:, 0] = [3, 3, 3]  (solution for b₁)
# x[:, 1] = [5, 5, 5]  (solution for b₂)
Diagram: A (one matrix) and b₁, b₂ (2 right-hand sides) → solve → x₁, x₂ (2 solutions)
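To sanity-check the shapes, the same small system can be solved densely with NumPy. This is only a reference sketch: it does not call klujax, and `A_dense` and `x_ref` are names introduced here for illustration.

```python
import numpy as np

# Dense form of the diagonal matrix used above: diag(2, 3, 4).
A_dense = np.diag([2.0, 3.0, 4.0])

# Each column of b is one right-hand side, so b has shape (3, 2).
b = np.array([
    [6.0, 10.0],
    [9.0, 15.0],
    [12.0, 20.0],
])

# np.linalg.solve treats the columns of a 2D b as independent right-hand sides.
x_ref = np.linalg.solve(A_dense, b)

print(x_ref[:, 0])  # solution for b₁
print(x_ref[:, 1])  # solution for b₂
```

The result has the same (3, 2) layout as b: column k of the output is the solution for column k of the input.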

Multiple Matrices (Batched Ax)

Solve different matrices (same pattern, different values) against different right-hand sides:

# 3 different diagonal matrices
Ax = jnp.array([
    [2.0, 3.0, 4.0],  # diag(2, 3, 4)
    [1.0, 1.0, 1.0],  # identity
    [4.0, 6.0, 8.0],  # diag(4, 6, 8)
])

# 3 corresponding right-hand sides
b = jnp.array([
    [6.0, 9.0, 12.0],
    [5.0, 5.0, 5.0],
    [8.0, 12.0, 16.0],
])

x = klujax.solve(Ai, Aj, Ax, b)
# x[0] = [3, 3, 3]
# x[1] = [5, 5, 5]
# x[2] = [2, 2, 2]
Diagram: A₁, A₂, A₃ (3 matrices) and b₁, b₂, b₃ (3 right-hand sides) → solve → x₁, x₂, x₃ (3 solutions)
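One way to picture the batched layout is to expand each row of Ax into its own dense matrix and solve the stack with NumPy. This is a sketch for checking results on small problems, not how klujax works internally; `A_dense` and `x_ref` are illustrative names, and the scatter step assumes the pattern has no repeated (row, column) pairs.

```python
import numpy as np

# Shared sparsity pattern (diagonal) and three sets of values.
Ai = np.array([0, 1, 2])
Aj = np.array([0, 1, 2])
Ax = np.array([
    [2.0, 3.0, 4.0],
    [1.0, 1.0, 1.0],
    [4.0, 6.0, 8.0],
])
b = np.array([
    [6.0, 9.0, 12.0],
    [5.0, 5.0, 5.0],
    [8.0, 12.0, 16.0],
])

n_lhs, n_col = b.shape

# Scatter every row of values into the same (Ai, Aj) positions,
# giving one dense (n_col, n_col) matrix per batch entry.
A_dense = np.zeros((n_lhs, n_col, n_col))
A_dense[:, Ai, Aj] = Ax

# Add a trailing axis so each b[k] is a column vector, solve the stack,
# then drop the trailing axis again.
x_ref = np.linalg.solve(A_dense, b[..., None])[..., 0]

print(x_ref)
```

Row k of `x_ref` is the solution of the k-th matrix against the k-th right-hand side, matching the pairing described above.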

Broadcasting: Many Matrices, One b

When Ax is 2D (batched) but b is 1D, the same b is used for all matrices:

Ax = jnp.array([
    [2.0, 3.0, 4.0],
    [1.0, 1.0, 1.0],
])
b = jnp.array([6.0, 9.0, 12.0])  # same b for both

x = klujax.solve(Ai, Aj, Ax, b)
# x[0] = [3, 3, 3]   (solved with Ax[0])
# x[1] = [6, 9, 12]  (solved with Ax[1] = identity)
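The broadcasting rule amounts to repeating b once per matrix before solving. The NumPy sketch below makes that explicit; it is a reference check only, and the element-wise division is valid here solely because both example matrices are diagonal. `b_batched` and `x_ref` are names introduced for illustration.

```python
import numpy as np

# Diagonal values of the two matrices from the example above.
Ax = np.array([
    [2.0, 3.0, 4.0],
    [1.0, 1.0, 1.0],
])
b = np.array([6.0, 9.0, 12.0])

n_lhs = Ax.shape[0]
n_col = b.shape[0]

# Repeat the single right-hand side once per matrix: shape (n_lhs, n_col).
b_batched = np.broadcast_to(b, (n_lhs, n_col))

# For a diagonal matrix, solving reduces to dividing by the diagonal.
x_ref = b_batched / Ax

print(x_ref)
```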

Using vmap for Extra Dimensions

For dimensions beyond what the built-in batching supports, use jax.vmap:

import jax

# 3D arrays: outer batch of 5, inner batch of 3
Ax_3d = jnp.ones((5, 3, 3))  # (5, n_lhs=3, n_nz=3)
b_3d = jnp.ones((5, 3, 3))   # (5, n_lhs=3, n_col=3)

# vmap over the outer dimension; Ai and Aj are shared, so they get None
x_3d = jax.vmap(klujax.solve, in_axes=(None, None, 0, 0))(Ai, Aj, Ax_3d, b_3d)

vmap Axis Options

# Batch over Ax only (same b for all)
x = jax.vmap(klujax.solve, in_axes=(None, None, 0, None))(Ai, Aj, Ax_batch, b)

# Batch over b only (same matrix for all)
x = jax.vmap(klujax.solve, in_axes=(None, None, None, 0))(Ai, Aj, Ax, b_batch)

# Batch over both
x = jax.vmap(klujax.solve, in_axes=(None, None, 0, 0))(Ai, Aj, Ax_batch, b_batch)

Note

Ai and Aj are never batched — all systems share the same sparsity pattern. Always use in_axes=None for the index arrays.
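The effect of each in_axes choice can be tried out without klujax by using a dense solver as a stand-in. In the sketch below, `dense_solve` is a hypothetical helper built on `jnp.linalg.solve`; it takes only a matrix and a right-hand side, so its in_axes tuples have two entries rather than four.

```python
import jax
import jax.numpy as jnp


def dense_solve(A, b):
    # Hypothetical stand-in for a single, unbatched solve.
    return jnp.linalg.solve(A, b)


A = jnp.diag(jnp.array([2.0, 3.0, 4.0]))  # one matrix, shape (3, 3)
A_batch = jnp.stack([A, 2.0 * A])         # two matrices, shape (2, 3, 3)
b = jnp.array([6.0, 9.0, 12.0])           # one right-hand side, shape (3,)
b_batch = jnp.stack([b, 2.0 * b])         # two right-hand sides, shape (2, 3)

# Batch over the matrix only: the same b is reused for every matrix.
x_a = jax.vmap(dense_solve, in_axes=(0, None))(A_batch, b)

# Batch over b only: the same matrix is reused for every right-hand side.
x_b = jax.vmap(dense_solve, in_axes=(None, 0))(A, b_batch)

# Batch over both: matrix k is paired with right-hand side k.
x_ab = jax.vmap(dense_solve, in_axes=(0, 0))(A_batch, b_batch)

print(x_a.shape, x_b.shape, x_ab.shape)
```

In all three cases the output gains a leading axis whose length equals the size of the mapped axis; arguments marked None are passed through unchanged to every call.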

Complete Example: Parameter Sweep

import jax
import klujax
import jax.numpy as jnp

# Base matrix pattern
Ai = jnp.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32)
Aj = jnp.array([0, 1, 0, 1, 1, 2], dtype=jnp.int32)
base_Ax = jnp.array([4.0, -1.0, -1.0, 4.0, -1.0, 4.0])

# Sweep over 100 scaling factors
scales = jnp.linspace(0.5, 2.0, 100)
Ax_sweep = scales[:, None] * base_Ax[None, :]  # (100, 6)

b = jnp.array([1.0, 0.0, 0.0])

# Solve all 100 systems at once
x_all = klujax.solve(Ai, Aj, Ax_sweep, b)
# x_all has shape (100, 3)
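Because every matrix in this sweep is a scalar multiple of the same base matrix, the expected result has a closed form: if (s·A) x = b, then x = A⁻¹ b / s. The NumPy sketch below uses that to build a reference for checking a sweep like the one above. It is a dense check that does not call klujax; `A_base`, `x_base`, `x_expected`, and `x_brute` are illustrative names.

```python
import numpy as np

# Same pattern, values, scales, and right-hand side as the sweep above.
Ai = np.array([0, 0, 1, 1, 2, 2])
Aj = np.array([0, 1, 0, 1, 1, 2])
base_Ax = np.array([4.0, -1.0, -1.0, 4.0, -1.0, 4.0])
scales = np.linspace(0.5, 2.0, 100)
b = np.array([1.0, 0.0, 0.0])

# Dense form of the base matrix (no repeated (row, column) pairs here).
A_base = np.zeros((3, 3))
A_base[Ai, Aj] = base_Ax

# Solve once for the unscaled matrix.
x_base = np.linalg.solve(A_base, b)

# Scaling the matrix by s divides the solution by s.
x_expected = x_base[None, :] / scales[:, None]  # shape (100, 3)

# Brute-force check: solve each scaled system separately.
x_brute = np.stack([np.linalg.solve(s * A_base, b) for s in scales])

print(np.allclose(x_expected, x_brute))
```

A comparison such as `np.allclose(x_all, x_expected)` could then be used to validate the batched result, with a tolerance chosen for the floating-point precision in use.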