Sum duplicate entries in a COO sparse matrix. If multiple entries share the same (i, j) position, they are replaced by a single entry whose value is their sum.
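The following is a minimal NumPy sketch of the same operation, included to make the semantics concrete. `coalesce_reference` is a hypothetical helper written for illustration and is not part of klujax. It works in three steps:

- It encodes each `(row, col)` pair as a single integer key.
- It groups equal keys with `np.unique`.
- It sums each group with `np.add.at`.

Returning positions in row-major order is a property of this sketch. It is not a documented guarantee of `klujax.coalesce`.

```python
import numpy as np


def coalesce_reference(Ai, Aj, Ax):
    """Sum values that share a (row, col) position.

    Hypothetical NumPy reference sketch, not part of klujax. Output positions
    are sorted row-major, which may differ from klujax.coalesce.
    """
    Ai = np.asarray(Ai, dtype=np.int64)
    Aj = np.asarray(Aj, dtype=np.int64)
    Ax = np.asarray(Ax)
    # Encode (row, col) as one integer; col < n_cols keeps keys distinct.
    n_cols = int(Aj.max()) + 1 if Aj.size else 1
    keys = Ai * n_cols + Aj
    # One key per distinct position; `inverse` maps each input entry to it.
    unique_keys, inverse = np.unique(keys, return_inverse=True)
    values = np.zeros(unique_keys.shape[0], dtype=Ax.dtype)
    # Unbuffered scatter-add: repeated indices accumulate, not overwrite.
    np.add.at(values, inverse, Ax)
    rows = (unique_keys // n_cols).astype(np.int32)
    cols = (unique_keys % n_cols).astype(np.int32)
    return rows, cols, values
```

For example, the entries (0, 1): 3.0, (0, 1): 1.0 and (1, 2): 4.0 reduce to (0, 1): 4.0 and (1, 2): 4.0.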
## Parameters
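| Parameter | Type | Shape | Description |
|---|---|---|---|
| `Ai` | int32 | `(n_nz,)` | Row indices (may contain duplicates) |
| `Aj` | int32 | `(n_nz,)` | Column indices (may contain duplicates) |
| `Ax` | float64 or complex128 | `(n_nz,)` or `(n_lhs, n_nz)` | Nonzero values; with shape `(n_lhs, n_nz)`, each row holds the values of one matrix, and all rows share the sparsity pattern given by `Ai` and `Aj` |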
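When `Ax` has shape `(n_lhs, n_nz)`, the index arrays are shared, so a single grouping applies to every row of values. The sketch below assumes that reading of the shapes. `coalesce_batched_reference` is a hypothetical NumPy helper, not a klujax function. In this sketch the returned values have shape `(n_lhs, n_unique)`.

```python
import numpy as np


def coalesce_batched_reference(Ai, Aj, Ax):
    """Sum duplicates for values of shape (n_lhs, n_nz).

    Hypothetical NumPy sketch, not part of klujax. All n_lhs rows share the
    sparsity pattern (Ai, Aj), so one grouping is applied to every row.
    """
    Ai = np.asarray(Ai, dtype=np.int64)
    Aj = np.asarray(Aj, dtype=np.int64)
    Ax = np.atleast_2d(np.asarray(Ax))  # (n_nz,) becomes (1, n_nz)
    n_cols = int(Aj.max()) + 1 if Aj.size else 1
    unique_keys, inverse = np.unique(Ai * n_cols + Aj, return_inverse=True)
    values = np.zeros((Ax.shape[0], unique_keys.shape[0]), dtype=Ax.dtype)
    # Scatter-add along the last axis: transposing makes it the first axis,
    # and values.T is a view, so np.add.at writes through to `values`.
    np.add.at(values.T, inverse, Ax.T)
    rows = (unique_keys // n_cols).astype(np.int32)
    cols = (unique_keys % n_cols).astype(np.int32)
    return rows, cols, values
```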
## Returns

| Type | Description |
|---|---|
| `(Ai, Aj, Ax)` | Deduplicated COO arrays; `n_nz` is reduced to the number of unique positions. |
## Why You Need This

klujax requires coalesced input: each (row, column) pair must appear at most once. If your data contains duplicates, coalesce it first. Duplicates are common when assembling finite element matrices.
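A cheap way to decide whether coalescing is needed is to compare the number of index pairs with the number of distinct pairs. `has_duplicate_positions` is a hypothetical helper, not a klujax function. It converts the indices to Python tuples, so it is meant for concrete arrays outside `jax.jit`.

```python
import numpy as np


def has_duplicate_positions(Ai, Aj):
    """Return True if any (row, col) pair occurs more than once.

    Hypothetical pre-flight check, not part of klujax.
    """
    pairs = list(zip(np.asarray(Ai).tolist(), np.asarray(Aj).tolist()))
    # A set keeps one copy of each distinct pair.
    return len(set(pairs)) < len(pairs)
```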
```mermaid
flowchart LR
subgraph Before
B["#40;0,1#41;: 3.0\n#40;0,1#41;: 1.0\n#40;1,2#41;: 4.0"]
end
subgraph After
A["#40;0,1#41;: 4.0\n#40;1,2#41;: 4.0"]
end
Before -->|coalesce| After
```
## Example

```python
import klujax
import jax.numpy as jnp

# Two entries at position (0, 0) and one at (1, 1)
Ai = jnp.array([0, 0, 1], dtype=jnp.int32)
Aj = jnp.array([0, 0, 1], dtype=jnp.int32)
Ax = jnp.array([2.0, 3.0, 4.0])

Ai, Aj, Ax = klujax.coalesce(Ai, Aj, Ax)
# Ai = [0, 1]
# Aj = [0, 1]
# Ax = [5.0, 4.0]  (2.0 + 3.0 = 5.0)
```
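> **Not JIT-compatible.** `coalesce` uses `jax.ensure_compile_time_eval()` internally and cannot be used inside `jax.jit`. The number of unique positions, and therefore the output shape, depends on the values in `Ai` and `Aj`, not only on their shapes. Always call it before entering any JIT-compiled function.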
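The reason tracing cannot work is that `jax.jit` must determine every output shape from the input shapes alone. The length of the coalesced arrays, however, depends on the index values. This NumPy sketch uses a hypothetical helper, `n_unique_positions`, to show two inputs of identical shape that coalesce to different lengths.

```python
import numpy as np


def n_unique_positions(Ai, Aj):
    """Number of distinct (row, col) pairs, i.e. the coalesced n_nz."""
    return len(set(zip(np.asarray(Ai).tolist(), np.asarray(Aj).tolist())))


# Both inputs have the same shape (3,) ...
Ai = np.array([0, 0, 1], dtype=np.int32)
Aj_dup = np.array([0, 0, 1], dtype=np.int32)    # (0, 0) appears twice
Aj_nodup = np.array([0, 1, 1], dtype=np.int32)  # all positions distinct

# ... but the coalesced length differs, so the output shape depends on values.
print(n_unique_positions(Ai, Aj_dup))    # 2
print(n_unique_positions(Ai, Aj_nodup))  # 3
```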
## When Do Duplicates Appear?

Duplicates typically arise when:

- Assembling finite element matrices: each element adds its own contribution to the entries of nodes it shares with neighboring elements
- Building matrices from edge lists: the same edge may be listed more than once
- Merging sub-matrices: combining sparse blocks can produce overlapping entries
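The first case can be illustrated with a toy mesh chosen for this example: a 1D bar with three nodes and two linear elements. The sketch assembles the bar into COO triplets, and the shared middle node produces a duplicate `(1, 1)` entry. Scatter-adding into a dense array with `np.add.at` sums the duplicates. This is equivalent to coalescing and then densifying, and it yields the tridiagonal stiffness matrix.

```python
import numpy as np

# Toy mesh: a 1D bar with 3 nodes and 2 linear elements, [0, 1] and [1, 2].
elements = [(0, 1), (1, 2)]
# Unit element stiffness matrix of a linear bar element.
k_e = np.array([[1.0, -1.0], [-1.0, 1.0]])

rows, cols, vals = [], [], []
for nodes in elements:
    # Each element emits a 2x2 block of COO triplets for its own two nodes.
    for a in range(2):
        for b in range(2):
            rows.append(nodes[a])
            cols.append(nodes[b])
            vals.append(k_e[a, b])

Ai = np.array(rows, dtype=np.int32)
Aj = np.array(cols, dtype=np.int32)
Ax = np.array(vals)

# 8 triplets but only 7 distinct positions: (1, 1) comes from both elements.
n_unique = len(set(zip(rows, cols)))

# Scatter-add sums duplicates, equivalent to coalescing and then densifying.
K = np.zeros((3, 3))
np.add.at(K, (Ai, Aj), Ax)
# K == [[ 1, -1,  0],
#       [-1,  2, -1],
#       [ 0, -1,  1]]
```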