Creating a PyTorch solver for sparse linear systems
posted on 2020-10-17T11:04:18Z

Not so long ago, I implemented a wrapper library in PyTorch to solve sparse linear systems on the CPU using the SuiteSparse routines under the hood. My goal is to eventually integrate this sparse solver into my photonic circuit simulator, Photontorch. However, for now, I thought it would be instructive to go over the steps I took to implement both the forward pass and the backward pass of such a custom PyTorch function.
So, with that in mind, the goal of the discussion below is to define a custom PyTorch function that solves sparse linear systems of the form $$ \begin{align} Ax &= b \end{align} $$ where $A$ is sparse and $x$ and $b$ are dense.
A very simple system of equations: Scalar x Scalar
Let's start by implementing the $0D$ case. When $A$ is a $0D$ matrix (i.e. just a normal number or scalar), this system of equations has the following solution:
$$ \begin{align*} x &= \frac{b}{A} \end{align*} $$
Implementing custom functions (with a custom forward and backward) in PyTorch is usually done by subclassing torch.autograd.Function and defining the forward and backward static methods of the derived class. For example, for our simple linear system with the solution above, we can write:
import torch
from torch.autograd import Function
class Solve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim > 0 and not (A.ndim==1 and A.shape[0]==1):
            raise ValueError("A should be 0D")
        ctx.save_for_backward(A, b)
        return b / A
    @staticmethod
    def backward(ctx, grad):
        A, b = ctx.saved_tensors
        return (-b/A**2)*grad, (1/A)*grad
This works as follows. The forward pass performs the operation and stores the tensors it needs in the context object ctx, so that they can be reused during the gradient calculation in the backward pass. In the backward pass, the gradient of the loss $L$ must be calculated with respect to all inputs of the forward pass ($A$ and $b$ in this case), given the gradient of the loss with respect to the output of the forward pass.
This means that in this case we need to find the following two transformations to describe the backward pass: $$ \begin{align*} \frac{\partial L}{\partial x} \rightarrow \frac{\partial L}{\partial A} \\ \frac{\partial L}{\partial x} \rightarrow \frac{\partial L}{\partial b} \end{align*} $$
By applying the chain rule, we have:
$$ \begin{align*} \frac{\partial L}{\partial A} = \frac{\partial L}{\partial x}\frac{\partial x}{\partial A} && \frac{\partial L}{\partial b} = \frac{\partial L}{\partial x}\frac{\partial x}{\partial b} \end{align*} $$
Hence, defining the backward pass comes down to determining $\frac{\partial x}{\partial A}$ and $\frac{\partial x}{\partial b}$:
$$ \begin{align*} \frac{\partial x}{\partial A} = - \frac{b}{A^2} && \frac{\partial x}{\partial b} = \frac{1}{A} \end{align*} $$
These gradients are simple enough that we can be confident they are correct. However, PyTorch also offers a gradcheck function to verify whether we implemented the backward pass correctly:
from torch.autograd import gradcheck
A = torch.randn(1, requires_grad=True).squeeze()
b = torch.randn(1, requires_grad=True).squeeze()
gradcheck(Solve.apply, [A.double(), b.double()]) # gradcheck requires double precision
Here, we used Solve.apply to use the function during gradcheck. It is common in PyTorch to actually alias this to a lowercase function name:
solve = Solve.apply
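Under the hood, gradcheck compares the analytical gradients returned by backward against numerical estimates obtained with finite differences. The sketch below strips that idea down to plain Python for the scalar case. The helpers solve_scalar, analytic_grads and numeric_grads are hypothetical and used only for illustration; they are not part of torch:

```python
def solve_scalar(A, b):
    """Forward pass of the scalar system A * x = b."""
    return b / A


def analytic_grads(A, b, grad_x=1.0):
    """Backward pass from the text: (dL/dA, dL/db) given dL/dx."""
    return (-b / A**2) * grad_x, (1.0 / A) * grad_x


def numeric_grads(A, b, grad_x=1.0, eps=1e-6):
    """Central finite-difference estimates of (dL/dA, dL/db) given dL/dx."""
    dx_dA = (solve_scalar(A + eps, b) - solve_scalar(A - eps, b)) / (2 * eps)
    dx_db = (solve_scalar(A, b + eps) - solve_scalar(A, b - eps)) / (2 * eps)
    return dx_dA * grad_x, dx_db * grad_x


A, b = 1.7, -0.4
print(analytic_grads(A, b))  # gradients from the backward formulas
print(numeric_grads(A, b))   # finite-difference estimates, should agree closely
```

gradcheck performs essentially this comparison for every input element, which is also why it needs double precision: in single precision the finite-difference estimates are too noisy.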
A more general system: Matrix × Vector
OK, what we implemented above was quite trivial, but it mostly served as an example of how the forward and backward passes of custom functions are implemented in PyTorch. Now that that's behind us, we can move on to more interesting territory: general systems where $A$ is a square $2D$ matrix acting on a $1D$ vector.
In this case, the solution of the system of equations becomes:
$$ \begin{align*} x = A^{-1} \, b \end{align*} $$
However, calculating $A^{-1}$ is an expensive operation and should be avoided. Let's say, for the sake of our story, that we want to use scipy.linalg.solve instead:
import numpy as np
from scipy.linalg import solve as scipy_solve
Sidenote: PyTorch actually has a torch.solve function (superseded by torch.linalg.solve in more recent releases), which, in contrast to scipy.linalg.solve, also works on CUDA GPUs. Hence, in 99% of cases this is the function you'll want. However, we'll go along with scipy.linalg.solve here, as hopefully we'll learn something from writing the PyTorch wrapper. At the end of this post, we'll then implement our own sparse solver in C++.
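To make the difference between solving and inverting concrete, here is a minimal pure-Python sketch of a dense solve using Gaussian elimination with partial pivoting. It never forms $A^{-1}$. The helper gauss_solve is hypothetical and only illustrates the idea. scipy.linalg.solve itself calls optimized LAPACK routines based on an LU factorization:

```python
def gauss_solve(A, b):
    """Solve A x = b for a square dense A (nested lists) without forming A^-1.

    Gaussian elimination with partial pivoting, followed by back substitution.
    """
    n = len(A)
    # Work on an augmented copy [A | b] so that the inputs are left untouched.
    M = [[float(v) for v in row] + [float(bi)] for row, bi in zip(A, b)]
    for k in range(n):
        # Partial pivoting: move the largest remaining entry of column k to row k.
        p = max(range(k, n), key=lambda i: abs(M[i][k]))
        if M[p][k] == 0.0:
            raise ValueError("matrix is singular")
        M[k], M[p] = M[p], M[k]
        # Eliminate column k below the pivot.
        for i in range(k + 1, n):
            f = M[i][k] / M[k][k]
            for j in range(k, n + 1):
                M[i][j] -= f * M[k][j]
    # Back substitution on the resulting upper-triangular system.
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        s = sum(M[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (M[i][n] - s) / M[i][i]
    return x


A = [[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]]
b = [1.0, 2.0, 3.0]
x = gauss_solve(A, b)
print(x)
```

Solving costs about the same as a single factorization, whereas forming $A^{-1}$ and then multiplying costs more work and is typically less accurate. The PyTorch wrapper around scipy.linalg.solve then looks as follows: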
class Solve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 2 or (A.shape[0] != A.shape[1]):
            raise ValueError("A should be a square 2D matrix.")
        A_np = A.detach().numpy()
        b_np = b.detach().numpy()
        x_np = scipy_solve(A_np, b_np)
        x = torch.from_numpy(x_np)  # no requires_grad needed: autograd attaches the graph to the output
        ctx.save_for_backward(A, b, x)
        return x
    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradb = Solve.apply(A.T, grad)
        gradA = -gradb[:,None] * x[None,:]
        return gradA, gradb
solve = Solve.apply
Here we used the following expressions for the gradients:
$$ \begin{align*} \frac{\partial L}{\partial b_k} &= \frac{\partial L}{\partial x_i} \frac{\partial x_i}{\partial b_k} \\ &= \frac{\partial L}{\partial x_i} \frac{\partial}{\partial b_k} ( A^{-1}_{ij} b_j ) \\ &= \frac{\partial L}{\partial x_i} A^{-1}_{ij} \frac{\partial b_j}{\partial b_k}\\ &= \frac{\partial L}{\partial x_i} A^{-1}_{ij} \delta_{jk}\\ &= \frac{\partial L}{\partial x_i} A^{-1}_{ik} \\ \Rightarrow \frac{\partial L}{\partial b} &= \big(A^{-1}\big)^{T} \frac{\partial L}{\partial x} = \mathrm{solve}\big( A^T\,,\,\, \frac{\partial L}{\partial x} \big) \\\\ \frac{\partial L}{\partial A_{mn}} &= \frac{\partial L}{\partial x_i} \frac{\partial x_i}{\partial A_{mn}} \\ &= \frac{\partial L}{\partial x_i} \frac{\partial}{\partial A_{mn}} ( A^{-1}_{ij} b_j ) \\ &= -\frac{\partial L}{\partial x_i} A^{-1}_{ij} \frac{\partial A_{jk}}{\partial A_{mn}} A^{-1}_{kl} b_l \\ &= -\frac{\partial L}{\partial x_i} A^{-1}_{ij} \delta_{jm} \delta_{kn} A^{-1}_{kl} b_l \\ &= -\frac{\partial L}{\partial x_i} A^{-1}_{im} A^{-1}_{nl} b_l \\ \Rightarrow \frac{\partial L}{\partial A} &= -\left(\big(A^{-1}\big)^{T} \frac{\partial L}{\partial x}\right)\otimes\left( A^{-1} b \right) = -\frac{\partial L}{\partial b} \otimes x \end{align*} $$
where we used the Einstein summation convention during the derivations and $\otimes$ denotes the outer product. We also used the following identity: $$ \frac{\partial (A^{-1})}{\partial p} = - A^{-1} \frac{\partial A}{\partial p} A^{-1}. $$
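The identity above is the key step in the $\partial L/\partial A$ derivation. In index form it reads $\partial A^{-1}_{ij}/\partial A_{mn} = -A^{-1}_{im} A^{-1}_{nj}$. The sketch below checks this numerically for a 2×2 example with central differences. It is plain Python, and inv2 and perturbed are hypothetical helpers:

```python
def inv2(A):
    """Closed-form inverse of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = A
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def perturbed(A, m, n, eps):
    """Copy of A with eps added to entry (m, n)."""
    B = [row[:] for row in A]
    B[m][n] += eps
    return B


A = [[3.0, 1.0], [2.0, 4.0]]
Ainv = inv2(A)
eps = 1e-6
max_err = 0.0
for m in range(2):
    for n in range(2):
        plus = inv2(perturbed(A, m, n, eps))
        minus = inv2(perturbed(A, m, n, -eps))
        for i in range(2):
            for j in range(2):
                numeric = (plus[i][j] - minus[i][j]) / (2 * eps)
                # index form of the identity: d(A^-1)_ij / dA_mn = -(A^-1)_im (A^-1)_nj
                analytic = -Ainv[i][m] * Ainv[n][j]
                max_err = max(max_err, abs(numeric - analytic))
print(max_err)  # tiny: finite-difference and roundoff error only
```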
We can check that the backward pass was implemented correctly:
A = torch.randn(3,3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
gradcheck(solve, [A.double(), b.double()]) # gradcheck requires double precision
Even more general: Matrix × Matrix
In fact, scipy.linalg.solve also supports Matrix × Matrix systems, hence extending the backward method of the above solver is not so difficult:
class Solve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 2 or (A.shape[0] != A.shape[1]):
            raise ValueError("A should be a square 2D matrix.")
        A_np = A.detach().numpy()
        b_np = b.detach().numpy()
        x_np = scipy_solve(A_np, b_np)
        x = torch.from_numpy(x_np)  # no requires_grad needed: autograd attaches the graph to the output
        ctx.save_for_backward(A, b, x)
        return x
    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradb = Solve.apply(A.T, grad)
        gradA = -gradb @ x.T
        return gradA, gradb
solve = Solve.apply
Here we used a slightly different derivation than above, taking into account that $x$ and $b$ are now $M\times N$ matrices (while $A$ remains an $M\times M$ matrix):
$$ \begin{align*} \frac{\partial L}{\partial b_{lm}} &= \frac{\partial L}{\partial x_{ij}} \frac{\partial x_{ij}}{\partial b_{lm}} \\ &= \frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \frac{\partial b_{kj}}{\partial b_{lm}}\\ &= \frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \delta_{kl}\delta_{jm}\\ &= \frac{\partial L}{\partial x_{im}} A^{-1}_{il}\\ \Rightarrow \frac{\partial L}{\partial b} &= \big(A^{-1}\big)^{T} \frac{\partial L}{\partial x} = \mathrm{solve}\big( A^T\,,\,\, \frac{\partial L}{\partial x} \big) \\\\ \frac{\partial L}{\partial A_{mn}} &= \frac{\partial L}{\partial x_{ij}} \frac{\partial x_{ij}}{\partial A_{mn}} \\ &= \frac{\partial L}{\partial x_{ij}} \frac{\partial}{\partial A_{mn}} ( A^{-1}_{ik} b_{kj} ) \\ &= -\frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \frac{\partial A_{kl}}{\partial A_{mn}} A^{-1}_{lp} b_{pj} \\ &= -\frac{\partial L}{\partial x_{ij}} A^{-1}_{ik} \delta_{km} \delta_{ln} A^{-1}_{lp} b_{pj} \\ &= -\frac{\partial L}{\partial x_{ij}} A^{-1}_{im} A^{-1}_{np} b_{pj} \\ \Rightarrow \frac{\partial L}{\partial A} &= -\left(\big(A^{-1}\big)^{T} \frac{\partial L}{\partial x}\right)\left( A^{-1} b \right)^T = -\frac{\partial L}{\partial b}\, x^T \end{align*} $$
A = torch.randn(3,3, requires_grad=True)
b = torch.randn(3,2)
gradcheck(solve, [A.double(), b.double()]) # gradcheck requires double precision
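To see the two matrix-case formulas at work independently of PyTorch, the sketch below takes the loss $L = \sum_{ij} W_{ij} x_{ij}$, so that $\partial L/\partial x = W$. It compares $\partial L/\partial b = (A^{-1})^T W$ and $\partial L/\partial A = -(\partial L/\partial b)\, x^T$ against central finite differences for a 2×2 $A$ and a 2×3 $b$. It is plain Python, and all helper names are hypothetical:

```python
def inv2(A):
    """Closed-form inverse of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = A
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def matmul(X, Y):
    """Dense matrix product of nested lists."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def transpose(X):
    return [list(col) for col in zip(*X)]


def loss(A, b, W):
    """L = sum_ij W_ij * x_ij with x = A^-1 b, so that dL/dx = W."""
    x = matmul(inv2(A), b)
    return sum(W[i][j] * x[i][j] for i in range(len(x)) for j in range(len(x[0])))


def central_diff(f, M, i, j, eps=1e-6):
    """Central finite difference of f with respect to entry (i, j) of M."""
    plus = [row[:] for row in M]
    minus = [row[:] for row in M]
    plus[i][j] += eps
    minus[i][j] -= eps
    return (f(plus) - f(minus)) / (2 * eps)


A = [[3.0, 1.0], [2.0, 4.0]]
b = [[1.0, -2.0, 0.5], [0.0, 1.0, 2.0]]   # M x N = 2 x 3
W = [[0.3, -1.0, 2.0], [1.5, 0.2, -0.7]]  # plays the role of dL/dx

x = matmul(inv2(A), b)
grad_b = matmul(transpose(inv2(A)), W)                                # solve(A^T, dL/dx)
grad_A = [[-v for v in row] for row in matmul(grad_b, transpose(x))]  # -(dL/db) x^T

err_A = max(abs(central_diff(lambda M: loss(M, b, W), A, i, j) - grad_A[i][j])
            for i in range(2) for j in range(2))
err_b = max(abs(central_diff(lambda M: loss(A, M, W), b, i, j) - grad_b[i][j])
            for i in range(2) for j in range(3))
print(err_A, err_b)
```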
Most general: Batched Matrix × Batched Matrix
The above Matrix $\times$ Matrix system is easily extended to a batched version (with the batch dimension being the first dimension):
class Solve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 3 or (A.shape[1] != A.shape[2]):
            raise ValueError("A should be a batch of square 2D matrices with shape (b, m, m)")
        A_np = A.detach().numpy()
        b_np = b.detach().numpy()
        x_np = np.stack([scipy_solve(A_np[i], b_np[i]) for i in range(A.shape[0])], 0)
        x = torch.from_numpy(x_np)  # no requires_grad needed: autograd attaches the graph to the output
        ctx.save_for_backward(A, b, x)
        return x
    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradb = Solve.apply(A.transpose(-1,-2), grad)
        gradA = -torch.bmm(gradb, x.transpose(-1,-2))
        return gradA, gradb
solve = Solve.apply
A = torch.randn(4,3,3, requires_grad=True)
b = torch.randn(4,3,2, requires_grad=True)
gradcheck(solve, [A.double(), b.double()]) # gradcheck requires double precision
Making a C++ extension
Let's now implement the above module in C++. However, we won't do this with the imported scipy function. For now, we'll just use torch::inverse where we need it. Just know that after we get this simple C++ PyTorch extension to work, we'll swap out torch::inverse for our own C++ solver.
We'll start by making a file called solve.cpp, which includes the following three headers:
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
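Generally speaking, torch/extension.h provides the C++ equivalents of the functions in the torch Python namespace (under torch::), together with the pybind11 headers needed to bind them to Python. ATen/ATen.h provides ATen, the underlying tensor library, whose at:: functions mirror the Python Tensor methods. The vector header is needed to return more than one tensor (as we'll do in the backward).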
Hence the batched matrix x batched matrix forward can be implemented in C++ as follows (remember we'll swap out torch::inverse for an actual solver later):
torch::Tensor solve_forward(torch::Tensor A, torch::Tensor b){
  auto result = torch::zeros_like(b);
  for (int i = 0; i < at::size(b, 0); i++){
      result[i] = torch::mm(torch::inverse(A[i]), b[i]); // we'll use an actual solver later.
  }
  return result;
}
Implementing the backward pass can also be done relatively easily:
std::vector<torch::Tensor> solve_backward(torch::Tensor grad, torch::Tensor A, torch::Tensor b, torch::Tensor x){
    // dL/db = solve(A^T, dL/dx), which has the same shape as b
    auto gradb = solve_forward(at::transpose(A, -1, -2), grad);
    // dL/dA = -(dL/db) x^T for every matrix in the batch
    auto gradA = -torch::bmm(gradb, at::transpose(x, -1, -2));
    return {gradA, gradb};
}
Notice that we're returning a vector of tensors here, which is why we needed to include <vector> earlier.
Finally, we need to expose the C++ functions we defined to Python. This is done with pybind11:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &solve_forward, "solve forward");
  m.def("backward", &solve_backward, "solve backward");
}
Here, the macro TORCH_EXTENSION_NAME is replaced during compilation by the name of the extension as defined in the setup.py file (BuildExtension passes it to the compiler as a define). The setup.py file looks as follows:
from setuptools import setup, Extension
from torch.utils import cpp_extension
solve_cpp = Extension(
    name="solve_cpp",
    sources=["solve.cpp"],
    include_dirs=cpp_extension.include_paths(),
    library_dirs=cpp_extension.library_paths(),
    extra_compile_args=[],
    libraries=[
        "c10",
        "torch",
        "torch_cpu",
        "torch_python",
    ],
    language="c++",
)
setup(
    name="solve",
    ext_modules=[solve_cpp],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
The C++ extension can now be compiled as follows:
python setup.py install
This compiles the extension into a shared library (a .so file on Linux, a .pyd file on Windows) and installs it into your Python's site-packages folder, so it will be on your Python path.
The only thing that's left is creating a thin wrapper in python:
import torch # always import torch BEFORE your custom torch extension
import solve_cpp # the custom torch extension we just created
from torch.autograd import Function, gradcheck
class Solve(Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 3 or (A.shape[1] != A.shape[2]):
            raise ValueError("A should be a batch of square 2D matrices with shape (b, m, m)")
        if b.ndim != 3:
            raise ValueError("b should be a batch of matrices with shape (b, m, n)")
        x = solve_cpp.forward(A, b)
        ctx.save_for_backward(A, b, x)
        return x
    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradA, gradb = solve_cpp.backward(grad, A, b, x)
        return gradA, gradb
solve = Solve.apply
A = torch.randn(4,3,3, requires_grad=True)
b = torch.randn(4,3,2)
gradcheck(solve, [A.double(), b.double()]) # gradcheck requires double precision
Towards a sparse solver for CPU tensors
OK, so now we have an expression for the backward pass and we know how to make a PyTorch C++ extension. Let's go on to make an actual sparse solver.
To create such a sparse solver, we'll wrap the SuiteSparse routines in our C++ extension. In particular, we'll wrap the KLU algorithm provided by this library. KLU is a very efficient solver for the sparse matrices that arise from circuit simulation netlists. This means it is most efficient for very sparse systems, often with only one element per column of the (connection) matrix. Obviously, it can tackle other linear systems of equations as well, but circuit simulation is what it was originally intended for.
To use the KLU algorithm, we'll have to add the klu header from the SuiteSparse library to the three headers we had in the C++ extension before.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <klu.h>
Note that with the Anaconda distribution you can simply run
conda install suitesparse
to pull in the SuiteSparse libraries and headers.
However, before we can go on to solving the sparse system with the KLU algorithm, there's one more hurdle to overcome:
Sparse COO -> Sparse CSC
The KLU algorithm expects the sparse matrix to be in CSC (compressed sparse column) format rather than in COO (coordinate) format, the standard sparse representation used in PyTorch. To do this conversion, we can look at how scipy does it and do something similar for our COO PyTorch tensors:
std::vector<at::Tensor> _coo_to_csc(int ncol, at::Tensor Ai, at::Tensor Aj, at::Tensor Ax) {
    int nnz = at::size(Ax, 0);
    at::TensorOptions options = at::TensorOptions().dtype(torch::kInt32).device(at::device_of(Ai));
    at::Tensor Bp = at::zeros(ncol+1, options);
    at::Tensor Bi = at::zeros_like(Ai);
    at::Tensor Bx = at::zeros_like(Ax);
    int* ai = Ai.data_ptr<int>();
    int* aj = Aj.data_ptr<int>();
    double* ax = Ax.data_ptr<double>();
    int* bp = Bp.data_ptr<int>();
    int* bi = Bi.data_ptr<int>();
    double* bx = Bx.data_ptr<double>();
    //compute number of non-zero entries per column of A
    for (int n = 0; n < nnz; n++) {
        bp[aj[n]] += 1;
    }
    //cumsum the nnz per column to get the start of each column in Bp
    int cumsum = 0;
    int temp = 0;
    for(int j = 0; j < ncol; j++) {
        temp = bp[j];
        bp[j] = cumsum;
        cumsum += temp;
    }
    bp[ncol] = nnz;
    //write Ai, Ax into Bi, Bx
    int col = 0;
    int dest = 0;
    for(int n = 0; n < nnz; n++) {
        col = aj[n];
        dest = bp[col];
        bi[dest] = ai[n];
        bx[dest] = ax[n];
        bp[col] += 1;
    }
    //bp[i] now points to the end of column i: shift back to get the column starts
    int last = 0;
    for(int i = 0; i <= ncol; i++) {
        temp = bp[i];
        bp[i] = last;
        last = temp;
    }
    return {Bp, Bi, Bx};
}
Note that we access the raw data of the PyTorch CPU tensors as C arrays (through their data_ptr). This means that this conversion only works on the CPU. However, performing the same element-by-element loops with native PyTorch tensor indexing would be a lot slower.
This function returns three PyTorch tensors: Bp, the column pointers; Bi, the row indices of the nonzero elements in each column; and Bx, the corresponding values of the sparse tensor.
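To make the three passes of the conversion easier to follow, here is a line-by-line pure-Python mirror of _coo_to_csc. It is a sketch for illustration: coo_to_csc is a hypothetical helper operating on lists, not part of the extension. The example matrix is small enough to trace by hand:

```python
def coo_to_csc(ncol, rows, cols, vals):
    """Pure-Python mirror of the C++ _coo_to_csc.

    rows, cols, vals: COO row indices, column indices and values.
    Returns (Bp, Bi, Bx): column pointers, row indices and values in CSC order.
    """
    nnz = len(vals)
    bp = [0] * (ncol + 1)
    bi = [0] * nnz
    bx = [0.0] * nnz
    # pass 1: count the nonzeros in each column
    for c in cols:
        bp[c] += 1
    # exclusive cumulative sum: bp[j] becomes the start of column j
    cumsum = 0
    for j in range(ncol):
        bp[j], cumsum = cumsum, cumsum + bp[j]
    bp[ncol] = nnz
    # pass 2: scatter every entry into the next free slot of its column
    for r, c, v in zip(rows, cols, vals):
        dest = bp[c]
        bi[dest] = r
        bx[dest] = v
        bp[c] += 1
    # pass 3: bp[j] now holds the end of column j; shift right to restore the starts
    last = 0
    for j in range(ncol + 1):
        bp[j], last = last, bp[j]
    return bp, bi, bx


# COO entries of [[1, 0, 2], [0, 3, 0], [4, 0, 5]] in row-major order
Bp, Bi, Bx = coo_to_csc(3, [0, 0, 1, 2, 2], [0, 2, 1, 0, 2], [1.0, 2.0, 3.0, 4.0, 5.0])
print(Bp)  # [0, 2, 3, 5]
print(Bi)  # [0, 2, 1, 0, 2]
print(Bx)  # [1.0, 4.0, 3.0, 2.0, 5.0]
```

Column $j$ occupies the slice Bp[j]:Bp[j+1] of Bi and Bx. Column 0, for example, holds the values 1 and 4 at rows 0 and 2.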
KLU Solver
Using these three tensors, we can define a KLU solver by wrapping the KLU routines as follows:
void _klu_solve(at::Tensor Ap, at::Tensor Ai, at::Tensor Ax, at::Tensor b) {
    int ncol = at::size(Ap, 0) - 1;
    int nb = at::size(b, 0);
    int* ap = Ap.data_ptr<int>();
    int* ai = Ai.data_ptr<int>();
    double* ax = Ax.data_ptr<double>();
    double* bb = b.data_ptr<double>();
    klu_symbolic* Symbolic;
    klu_numeric* Numeric;
    klu_common Common;
    klu_defaults(&Common);
    Symbolic = klu_analyze(ncol, ap, ai, &Common);
    Numeric = klu_factor(ap, ai, ax, Symbolic, &Common);
    klu_solve(Symbolic, Numeric, ncol, nb/ncol, bb, &Common);
    klu_free_symbolic(&Symbolic, &Common);
    klu_free_numeric(&Numeric, &Common);
}
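Using the KLU routines comes down to three steps. First, klu_analyze performs a symbolic analysis of the sparse matrix A; it only needs the sparsity pattern (Ap and Ai), for example to determine a fill-reducing ordering. Next, klu_factor computes the numerical LU factorization, which also uses the values Ax. Finally, the system is solved with klu_solve, using ncol as the leading dimension of b and nb/ncol as the number of right-hand sides. Note that klu_solve is an in-place operation on b, i.e. after solving, the solution x will be in the b tensor.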
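Because klu_solve overwrites b in place and the wrapper above does not inspect KLU's status, a cheap way to sanity-check a solution is to compute the residual $Ax - b$ directly from the CSC arrays. The sketch below shows how a product with a CSC matrix walks over the columns. csc_matvec and residual are hypothetical helpers, and a square $A$ is assumed:

```python
def csc_matvec(Ap, Ai, Ax, x):
    """y = A @ x for a square A stored in CSC format (column pointers, row indices, values)."""
    y = [0.0] * len(x)  # square A assumed: as many rows as columns
    for j in range(len(Ap) - 1):           # loop over the columns
        for k in range(Ap[j], Ap[j + 1]):  # nonzeros of column j
            y[Ai[k]] += Ax[k] * x[j]
    return y


def residual(Ap, Ai, Ax, x, b):
    """Largest absolute entry of A @ x - b."""
    return max(abs(yi - bi) for yi, bi in zip(csc_matvec(Ap, Ai, Ax, x), b))


# CSC form of [[1, 0, 2], [0, 3, 0], [4, 0, 5]]
Ap, Ai, Ax = [0, 2, 3, 5], [0, 2, 1, 0, 2], [1.0, 4.0, 3.0, 2.0, 5.0]
x = [1.0, 1.0, 1.0]
print(csc_matvec(Ap, Ai, Ax, x))               # [3.0, 3.0, 9.0]
print(residual(Ap, Ai, Ax, x, [3.0, 3.0, 9.0]))  # 0.0
```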
Updated Forward
Finally, we can update the forward function by using our _klu_solve wrapper instead of torch::inverse:
at::Tensor solve_forward(at::Tensor A, at::Tensor b) {
    int p = at::size(b, 0);
    int m = at::size(b, 1);
    int n = at::size(b, 2);
    at::Tensor bflat = at::clone(at::reshape(at::transpose(b, 1, 2), {p, m*n}));
    at::Tensor Ax = at::reshape(A._values(), {p, -1});
    at::Tensor Ai = at::reshape(at::_cast_Int(A._indices()[1]), {p, -1});
    at::Tensor Aj = at::reshape(at::_cast_Int(A._indices()[2]), {p, -1});
    for (int i = 0; i < p; i++) {
        std::vector<at::Tensor> Ap_Ai_Ax = _coo_to_csc(m, Ai[i], Aj[i], Ax[i]);
        _klu_solve(Ap_Ai_Ax[0], Ap_Ai_Ax[1], Ap_Ai_Ax[2], bflat[i]); // result will be in bflat
    }
    return at::transpose(bflat.view({p,n,m}), 1, 2);
}
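Two layout details of solve_forward are easy to miss. First, KLU reads the right-hand sides in column-major order. Transposing b before flattening therefore concatenates the $n$ columns (each of length $m$) of every batch element, which is the layout klu_solve expects with leading dimension ncol $= m$ and nb/ncol $= n$ right-hand sides. Viewing the result as (p, n, m) and transposing back undoes this. Second, reshaping A._values() and the indices to {p, -1} assumes that A is coalesced (so its entries are sorted by batch index) and that every matrix in the batch stores the same number of nonzeros. A shared mask, as in the example at the end of this post, guarantees the latter. The packing step for a single batch element can be sketched in plain Python; the helper names are hypothetical:

```python
def pack_column_major(b):
    """Flatten an m x n matrix column by column.

    Mirrors what transpose(b, 1, 2) followed by reshape does for one batch element.
    """
    m, n = len(b), len(b[0])
    return [b[i][j] for j in range(n) for i in range(m)]


def unpack_column_major(flat, m, n):
    """Inverse of pack_column_major: view as (n, m), then transpose back to (m, n)."""
    cols = [flat[j * m:(j + 1) * m] for j in range(n)]  # shape (n, m)
    return [[cols[j][i] for j in range(n)] for i in range(m)]


b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # m = 3 rows, n = 2 right-hand sides
flat = pack_column_major(b)
ldim = len(b)             # plays the role of ncol in _klu_solve
nrhs = len(flat) // ldim  # plays the role of nb / ncol
print(flat)  # [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]
print(nrhs)  # 2
```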
Updated Setup
The setup.py for this extension also becomes a bit more complex, as the SuiteSparse libraries need to be included:
import os
import glob
from setuptools import setup, Extension
from torch.utils import cpp_extension
libroot = os.path.dirname(os.path.dirname(os.__file__))
if os.name == "nt":  # Windows
    suitesparse_lib = os.path.join(libroot, "Library", "lib")
    suitesparse_include = os.path.join(libroot, "Library", "include", "suitesparse")
else:  # Linux / Mac OS
    suitesparse_lib = os.path.join(os.path.dirname(libroot), "lib")
    suitesparse_include = os.path.join(os.path.dirname(libroot), "include")
torch_sparse_solve_cpp = Extension(
    name="torch_sparse_solve_cpp",
    sources=["torch_sparse_solve.cpp"],
    include_dirs=[*cpp_extension.include_paths(), suitesparse_include],
    library_dirs=[*cpp_extension.library_paths(), suitesparse_lib],
    extra_compile_args=[],
    libraries=[
        "c10",
        "torch",
        "torch_cpu",
        "torch_python",
        "klu",
        "btf",
        "amd",
        "colamd",
        "suitesparseconfig",
    ],
    language="c++",
)
setup(
    name="torch_sparse_solve",
    ext_modules=[torch_sparse_solve_cpp],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
Updated Python wrapper
We can update the python wrapper to include the newly built C++ extension:
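import torch # always import torch BEFORE your custom torch extension
from torch.autograd import gradcheck
from torch_sparse_solve_cpp import solve_forward, solve_backward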
class Solve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        if A.ndim != 3 or (A.shape[1] != A.shape[2]) or not A.is_sparse:
            raise ValueError(
                "'A' should be a batch of square 2D sparse matrices with shape (b, m, m)."
            )
        if b.ndim != 3:
            raise ValueError("'b' should be a batch of matrices with shape (b, m, n).")
        if not A.dtype == torch.float64:
            raise ValueError("'A' should be a sparse float64 tensor (for now). Please first convert to float64.")
        if not b.dtype == torch.float64:
            raise ValueError("'b' should be a float64 tensor (for now). Please first convert to float64")
        x = solve_forward(A, b)
        ctx.save_for_backward(A, b, x)
        return x
    @staticmethod
    def backward(ctx, grad):
        A, b, x = ctx.saved_tensors
        gradA, gradb = solve_backward(grad, A, b, x)
        return gradA, gradb
    
solve = Solve.apply
mask = torch.tensor([[[1, 0, 0], [1, 1, 0], [0, 0, 1]]], dtype=torch.float64)
A = mask * torch.randn(4, 3, 3, dtype=torch.float64)
Asp = A.to_sparse()
Asp.requires_grad_()
b = torch.randn(4, 3, 2, dtype=torch.float64, requires_grad=True)
gradcheck(solve, [Asp, b], check_sparse_nnz=True) 
That's it!
Those were the steps I went through creating my first PyTorch C++ extension. Please check it out on GitHub: https://github.com/flaport/torch_sparse_solve and consider giving it a star 😉