Source code for klujax

""" klujax: a KLU solver for JAX """

# Metadata ============================================================================

__version__ = "0.3.1"
__author__ = "Floris Laporte"
__all__ = ["solve", "coo_mul_vec"]

# Imports =============================================================================

import sys
from functools import partial

import jax
import jax.extend
import jax.numpy as jnp
import numpy as np
from jax import core, lax
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir
from jaxtyping import Array

import klujax_cpp

# Config ==============================================================================

DEBUG = False
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
_log = lambda s: None if not DEBUG else print(s, file=sys.stderr)  # noqa: E731

# The main functions ==================================================================


@jax.jit
def solve(Ai: Array, Aj: Array, Ax: Array, b: Array) -> Array:
    """Solve for x in the sparse linear system Ax=b.

    Args:
        Ai: [n_nz; int32]: the row indices of the sparse matrix A
        Aj: [n_nz; int32]: the column indices of the sparse matrix A
        Ax: [n_lhs? x n_nz; float64|complex128]: the values of the sparse matrix A
        b: [n_lhs? x n_col x n_rhs?; float64|complex128]: the target vector

    Returns:
        x: the result (x≈A^-1b)
    """
    if any(x.dtype in COMPLEX_DTYPES for x in (Ax, b)):
        result = solve_c128.bind(
            Ai.astype(jnp.int32),
            Aj.astype(jnp.int32),
            Ax.astype(jnp.complex128),
            b.astype(jnp.complex128),
        )
    else:
        result = solve_f64.bind(
            Ai.astype(jnp.int32),
            Aj.astype(jnp.int32),
            Ax.astype(jnp.float64),
            b.astype(jnp.float64),
        )
    return result  # type: ignore
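

# Usage sketch (hypothetical helper, not part of klujax itself): solving a small
# 2x2 system given in COO form. The indices, values and right-hand side below are
# made-up illustration data.
def _example_solve():
    Ai = jnp.array([0, 0, 1, 1], dtype=jnp.int32)  # row indices of the nonzeros
    Aj = jnp.array([0, 1, 0, 1], dtype=jnp.int32)  # column indices of the nonzeros
    Ax = jnp.array([2.0, 1.0, 3.0, 4.0])  # nonzero values: A = [[2, 1], [3, 4]]
    b = jnp.array([1.0, 2.0])  # right-hand side
    x = solve(Ai, Aj, Ax, b)  # x ≈ A^-1 b, approximately [0.4, 0.2]
    return x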


@jax.jit
def coo_mul_vec(Ai: Array, Aj: Array, Ax: Array, x: Array) -> Array:
    """Multiply a sparse matrix with a vector: Ax=b

    Args:
        Ai: [n_nz; int32]: the row indices of the sparse matrix A
        Aj: [n_nz; int32]: the column indices of the sparse matrix A
        Ax: [n_lhs? x n_nz; float64|complex128]: the values of the sparse matrix A
        x: [n_lhs? x n_col x n_rhs?; float64|complex128]: the vector that's being
            multiplied by A

    Returns:
        b: the result of the multiplication (b=Ax)
    """
    if any(x.dtype in COMPLEX_DTYPES for x in (Ax, x)):
        result = coo_mul_vec_c128.bind(
            Ai.astype(jnp.int32),
            Aj.astype(jnp.int32),
            Ax.astype(jnp.complex128),
            x.astype(jnp.complex128),
        )
    else:
        result = coo_mul_vec_f64.bind(
            Ai.astype(jnp.int32),
            Aj.astype(jnp.int32),
            Ax.astype(jnp.float64),
            x.astype(jnp.float64),
        )
    return result  # type: ignore
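

# Usage sketch (hypothetical helper, not part of klujax itself): multiplying the
# same made-up 2x2 COO matrix with a vector, e.g. to verify a `solve` result.
def _example_coo_mul_vec():
    Ai = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
    Aj = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
    Ax = jnp.array([2.0, 1.0, 3.0, 4.0])  # A = [[2, 1], [3, 4]]
    b = jnp.array([1.0, 2.0])
    x = solve(Ai, Aj, Ax, b)  # solve A@x = b ...
    b_check = coo_mul_vec(Ai, Aj, Ax, x)  # ... and multiply back: b_check ≈ b
    return b_check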


# Constants ===========================================================================

COMPLEX_DTYPES = (
    np.complex64,
    np.complex128,
    # np.complex256,
    jnp.complex64,
    jnp.complex128,
)

# Primitives ==========================================================================

solve_f64 = core.Primitive("solve_f64")
solve_c128 = core.Primitive("solve_c128")
coo_mul_vec_f64 = core.Primitive("coo_mul_vec_f64")
coo_mul_vec_c128 = core.Primitive("coo_mul_vec_c128")

# Register XLA extensions =============================================================

jax.extend.ffi.register_ffi_target(
    "_solve_f64",
    klujax_cpp.solve_f64(),
    platform="cpu",
)

jax.extend.ffi.register_ffi_target(
    "_coo_mul_vec_f64",
    klujax_cpp.coo_mul_vec_f64(),
    platform="cpu",
)

jax.extend.ffi.register_ffi_target(
    "_solve_c128",
    klujax_cpp.solve_c128(),
    platform="cpu",
)

jax.extend.ffi.register_ffi_target(
    "_coo_mul_vec_c128",
    klujax_cpp.coo_mul_vec_c128(),
    platform="cpu",
)

# Helper Decorators ===================================================================


def ad_register(primitive):
    def decorator(fun):
        ad.primitive_jvps[primitive] = fun
        return fun

    return decorator


def transpose_register(primitive):
    def decorator(fun):
        ad.primitive_transposes[primitive] = fun
        return fun

    return decorator


def vmap_register(primitive, operation):
    def decorator(fun):
        batching.primitive_batchers[primitive] = partial(fun, operation)
        return fun

    return decorator


# Implementations =====================================================================


@solve_f64.def_impl
def solve_f64_impl(Ai, Aj, Ax, b) -> jnp.ndarray:
    Ai, Aj, Ax, b, shape = _prepare_arguments(Ai, Aj, Ax, b)
    _b = b.transpose(0, 2, 1)
    call = jax.extend.ffi.ffi_call(
        "_solve_f64",
        jax.ShapeDtypeStruct(_b.shape, _b.dtype),
        vmap_method="broadcast_all",
    )
    result = call(  # type: ignore
        Ai,
        Aj,
        Ax,
        _b,
    )
    return result.transpose(0, 2, 1).reshape(*shape)  # type: ignore


@solve_c128.def_impl
def solve_c128_impl(Ai, Aj, Ax, b) -> jnp.ndarray:
    Ai, Aj, Ax, b, shape = _prepare_arguments(Ai, Aj, Ax, b)
    _b = b.transpose(0, 2, 1)
    call = jax.extend.ffi.ffi_call(
        "_solve_c128",
        jax.ShapeDtypeStruct(_b.shape, _b.dtype),
        vmap_method="broadcast_all",
    )
    result = call(  # type: ignore
        Ai,
        Aj,
        Ax,
        _b,
    )
    return result.transpose(0, 2, 1).reshape(*shape)  # type: ignore


@coo_mul_vec_f64.def_impl
def coo_mul_vec_f64_impl(Ai, Aj, Ax, x) -> jnp.ndarray:
    Ai, Aj, Ax, x, shape = _prepare_arguments(Ai, Aj, Ax, x)
    _x = x.transpose(0, 2, 1)
    call = jax.extend.ffi.ffi_call(
        "_coo_mul_vec_f64",
        jax.ShapeDtypeStruct(_x.shape, _x.dtype),
        vmap_method="broadcast_all",
    )
    result = call(  # type: ignore
        Ai,
        Aj,
        Ax,
        _x,
    )
    return result.transpose(0, 2, 1).reshape(*shape)  # type: ignore


@coo_mul_vec_c128.def_impl
def coo_mul_vec_c128_impl(Ai, Aj, Ax, x) -> jnp.ndarray:
    Ai, Aj, Ax, x, shape = _prepare_arguments(Ai, Aj, Ax, x)
    _x = x.transpose(0, 2, 1)
    call = jax.extend.ffi.ffi_call(
        "_coo_mul_vec_c128",
        jax.ShapeDtypeStruct(_x.shape, _x.dtype),
        vmap_method="broadcast_all",
    )
    result = call(  # type: ignore
        Ai,
        Aj,
        Ax,
        _x,
    )
    return result.transpose(0, 2, 1).reshape(*shape)  # type: ignore


def _prepare_arguments(Ai, Aj, Ax, x):
    Ai = jnp.asarray(Ai)
    Aj = jnp.asarray(Aj)
    Ax = jnp.asarray(Ax)
    x = jnp.asarray(x)
    shape = x.shape
    _log(f"{Ai.dtype=}")
    _log(f"{Aj.dtype=}")
    _log(f"{Ai.max()=}")
    _log(f"{Aj.max()=}")
    _log(f"{Ax.dtype=}")
    _log(f"{x.dtype=}")
    _log(f"{Ax.shape=}")
    _log(f"{x.shape=}")
    prefer_x_rhs_over_lhs = (Ax.ndim < 2) or (Ax.shape[0] != x.shape[0])
    _log(f"{prefer_x_rhs_over_lhs=}")

    if Ax.ndim > 2:
        raise ValueError(
            f"Ax should be at most 2D with shape: (n_lhs, n_nz). Got: {Ax.shape}. "
            "Note: jax.vmap is supported. Use it if needed."
        )
    else:
        Ax = jnp.atleast_2d(Ax)
    a_n_lhs, n_nz = Ax.shape
    Ax = Ax.reshape(-1, n_nz)
    _log(f"{Ax.shape=}")
    _log(f"{n_nz=}")

    if x.ndim == 0:
        x = x[None, None, None]
        x_n_lhs, n_col, n_rhs = x.shape
        shape_includes_lhs = False
    elif x.ndim == 1:
        x = x[None, :, None]
        x_n_lhs, n_col, n_rhs = x.shape
        shape_includes_lhs = False
    elif x.ndim == 2 and prefer_x_rhs_over_lhs:
        x = x[None, :, :]
        x_n_lhs, n_col, n_rhs = x.shape
        shape_includes_lhs = False
    elif x.ndim == 2 and not prefer_x_rhs_over_lhs:
        x = x[:, :, None]
        x_n_lhs, n_col, n_rhs = x.shape
        shape_includes_lhs = True
    elif x.ndim == 3:
        x_n_lhs, n_col, n_rhs = x.shape
        shape_includes_lhs = True
    else:
        raise ValueError(
            f"x should be at most 3D with shape: (n_lhs, n_col, n_rhs). Got: {x.shape}. "
            "Note: jax.vmap is supported. Use it if needed."
        )
    _log(f"{x.shape=}")
    _log(f"{n_col=}")
    _log(f"{n_rhs=}")

    if a_n_lhs == x_n_lhs:
        n_lhs = a_n_lhs
    elif a_n_lhs > x_n_lhs:
        if not x_n_lhs == 1:
            raise ValueError(
                f"Cannot broadcast n_lhs for x into n_lhs for Ax. "
                f"Got: n_lhs[x]={x_n_lhs}; n_lhs[Ax]={a_n_lhs}."
            )
        n_lhs = a_n_lhs
        if shape_includes_lhs:
            shape = (n_lhs,) + shape[1:]
        else:
            shape = (n_lhs,) + shape
    else:
        if not a_n_lhs == 1:
            raise ValueError(
                f"Cannot broadcast n_lhs for Ax into n_lhs for x. "
                f"Got: n_lhs[x]={x_n_lhs}; n_lhs[Ax]={a_n_lhs}."
            )
        n_lhs = x_n_lhs
    _log(f"{n_lhs=}")

    Ax = jnp.broadcast_to(Ax, (n_lhs, n_nz))
    x = jnp.broadcast_to(x, (n_lhs, n_col, n_rhs))
    _log(f"{Ax.shape=}")
    _log(f"{x.shape=}")

    # We retain the old shape of b so the result of the primitive can be
    # reshaped to the expected shape.
    return Ai, Aj, Ax, x, shape
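

# Shape-convention sketch (hypothetical helper, not part of klujax itself): Ax with
# shape (n_lhs, n_nz) carries a batch of matrices sharing one sparsity pattern, and
# b with shape (n_lhs, n_col) carries one right-hand side per matrix.
def _example_batched_lhs():
    Ai = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
    Aj = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
    Ax = jnp.array(
        [[2.0, 1.0, 3.0, 4.0], [1.0, 0.0, 0.0, 1.0]]
    )  # A0 = [[2, 1], [3, 4]], A1 = identity
    b = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # one right-hand side per matrix
    x = solve(Ai, Aj, Ax, b)  # shape (2, 2): x[i] solves A_i @ x[i] = b[i]
    return x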


# Lowerings ===========================================================================

solve_f64_lowering = mlir.lower_fun(solve_f64_impl, multiple_results=False)
mlir.register_lowering(solve_f64, solve_f64_lowering)

solve_c128_lowering = mlir.lower_fun(solve_c128_impl, multiple_results=False)
mlir.register_lowering(solve_c128, solve_c128_lowering)

coo_mul_vec_f64_lowering = mlir.lower_fun(coo_mul_vec_f64_impl, multiple_results=False)
mlir.register_lowering(coo_mul_vec_f64, coo_mul_vec_f64_lowering)

coo_mul_vec_c128_lowering = mlir.lower_fun(
    coo_mul_vec_c128_impl, multiple_results=False
)
mlir.register_lowering(coo_mul_vec_c128, coo_mul_vec_c128_lowering)

# Abstract Evaluations ================================================================


@solve_f64.def_abstract_eval
@solve_c128.def_abstract_eval
@coo_mul_vec_f64.def_abstract_eval
@coo_mul_vec_c128.def_abstract_eval
def coo_vec_operation_abstract_eval(Ai, Aj, Ax, b):
    return ShapedArray(b.shape, b.dtype)


# Forward Gradients ===================================================================


@ad_register(solve_f64)
@ad_register(solve_c128)
def solve_value_and_jvp(arg_values, arg_tangents):
    Ai, Aj, Ax, b = arg_values
    dAi, dAj, dAx, db = arg_tangents
    dAx = dAx if not isinstance(dAx, ad.Zero) else lax.zeros_like_array(Ax)
    dAi = dAi if not isinstance(dAi, ad.Zero) else lax.zeros_like_array(Ai)
    dAj = dAj if not isinstance(dAj, ad.Zero) else lax.zeros_like_array(Aj)
    db = db if not isinstance(db, ad.Zero) else lax.zeros_like_array(b)
    x = solve(Ai, Aj, Ax, b)
    dA_x = coo_mul_vec(Ai, Aj, dAx, x)
    invA_dA_x = solve(Ai, Aj, Ax, dA_x)
    invA_db = solve(Ai, Aj, Ax, db)
    return x, -invA_dA_x + invA_db


@ad_register(coo_mul_vec_f64)
@ad_register(coo_mul_vec_c128)
def coo_mul_vec_value_and_jvp(arg_values, arg_tangents):
    Ai, Aj, Ax, b = arg_values
    dAi, dAj, dAx, db = arg_tangents
    dAx = dAx if not isinstance(dAx, ad.Zero) else lax.zeros_like_array(Ax)
    dAi = dAi if not isinstance(dAi, ad.Zero) else lax.zeros_like_array(Ai)
    dAj = dAj if not isinstance(dAj, ad.Zero) else lax.zeros_like_array(Aj)
    db = db if not isinstance(db, ad.Zero) else lax.zeros_like_array(b)
    x = coo_mul_vec(Ai, Aj, Ax, b)
    dA_b = coo_mul_vec(Ai, Aj, dAx, b)
    A_db = coo_mul_vec(Ai, Aj, Ax, db)
    return x, dA_b + A_db
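

# Differentiation sketch (hypothetical helper, not part of klujax itself): the JVP
# rules registered above make `solve` forward-mode differentiable, e.g. w.r.t. the
# nonzero values Ax via jax.jvp.
def _example_jvp():
    Ai = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
    Aj = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
    Ax = jnp.array([2.0, 1.0, 3.0, 4.0])
    dAx = jnp.array([0.1, 0.0, 0.0, 0.1])  # tangent for the nonzero values
    b = jnp.array([1.0, 2.0])
    x, dx = jax.jvp(lambda values: solve(Ai, Aj, values, b), (Ax,), (dAx,))
    return x, dx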


# Backward Gradients through Transposition ============================================


@transpose_register(solve_f64)
@transpose_register(solve_c128)
def solve_transpose(ct, Ai, Aj, Ax, b):
    assert not ad.is_undefined_primal(Ai)
    assert not ad.is_undefined_primal(Aj)
    assert not ad.is_undefined_primal(Ax)
    assert ad.is_undefined_primal(b)
    return None, None, None, solve(Aj, Ai, Ax.conj(), ct)  # = inv(A).H@ct [= ct@inv(A)]


@transpose_register(coo_mul_vec_f64)
@transpose_register(coo_mul_vec_c128)
def coo_mul_vec_transpose(ct, Ai, Aj, Ax, b):
    assert not ad.is_undefined_primal(Ai)
    assert not ad.is_undefined_primal(Aj)
    assert ad.is_undefined_primal(Ax) != ad.is_undefined_primal(b)  # xor

    if ad.is_undefined_primal(b):
        return None, None, None, coo_mul_vec(Aj, Ai, Ax.conj(), ct)  # = A.T@ct [= ct@A]
    else:
        dA = ct[Ai] * b[Aj]
        dA = dA.reshape(dA.shape[0], -1).sum(-1)  # not sure about this...
        return None, None, dA, None


# Vectorization (vmap) ================================================================


@vmap_register(solve_f64, solve)
@vmap_register(solve_c128, solve)
@vmap_register(coo_mul_vec_f64, coo_mul_vec)
@vmap_register(coo_mul_vec_c128, coo_mul_vec)
def coo_vec_operation_vmap(operation, vector_arg_values, batch_axes):
    aAi, aAj, aAx, ab = batch_axes
    Ai, Aj, Ax, b = vector_arg_values

    assert aAi is None, "Ai cannot be vectorized."
    assert aAj is None, "Aj cannot be vectorized."

    if aAx is not None and ab is not None:
        assert isinstance(aAx, int) and isinstance(ab, int)
        Ax = jnp.moveaxis(Ax, aAx, 0)  # treat as lhs
        b = jnp.moveaxis(b, ab, 0)  # treat as lhs
        result = operation(Ai, Aj, Ax, b)
        return result, 0

    if ab is None:
        assert isinstance(aAx, int)
        Ax = jnp.moveaxis(Ax, aAx, 0)  # treat as lhs
        b = jnp.broadcast_to(b[None], (Ax.shape[0], *b.shape))
        result = operation(Ai, Aj, Ax, b)
        return result, 0

    if aAx is None:
        assert isinstance(ab, int)
        _log(f"vmap: {b.shape=}")
        b = jnp.moveaxis(b, ab, -1)  # treat as rhs
        _log(f"vmap: {b.shape=}")
        shape = b.shape
        if b.ndim == 0:
            b = b[None, None, None]
        elif b.ndim == 1:
            b = b[None, None, :]
        elif b.ndim == 2:
            b = b[None, :, :]
        elif b.ndim == 3:
            b = b[:, :, :]
        b = b.reshape(b.shape[0], b.shape[1], -1)
        _log(f"vmap: {b.shape=}")
        # b is now guaranteed to have shape (n_lhs, n_col, n_rhs)
        result = operation(Ai, Aj, Ax, b)
        result = result.reshape(*shape)
        return result, -1

    raise ValueError("invalid arguments for vmap")
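

# Batching sketch (hypothetical helper, not part of klujax itself): the batching
# rules registered above let jax.vmap map `solve` over a leading axis of Ax while
# reusing a single right-hand side b.
def _example_vmap():
    Ai = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
    Aj = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
    Ax_batch = jnp.array(
        [[2.0, 1.0, 3.0, 4.0], [1.0, 0.0, 0.0, 1.0]]
    )  # shape (2, n_nz): two matrices with the same sparsity pattern
    b = jnp.array([1.0, 2.0])
    batched_solve = jax.vmap(solve, in_axes=(None, None, 0, None))
    return batched_solve(Ai, Aj, Ax_batch, b)  # shape (2, 2): one solution per matrix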