Currently, MPAX uses isinstance to distinguish between dense and sparse matrices. This could be improved by using jax.experimental.sparse.sparsify, which provides a more general and composable way to handle both types.
Whether this is practical depends on how mature sparsify is (it still lives under jax.experimental), but it’s worth exploring.
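
To make the idea concrete, here is a minimal sketch contrasting the two approaches. `matvec_dispatch` and `matvec` are hypothetical helpers written for illustration and are not taken from MPAX's code. The sketch assumes BCOO is the sparse format in use.

```python
import jax.numpy as jnp
from jax.experimental import sparse


# Hypothetical isinstance-based dispatch (illustrative, not MPAX's actual code).
def matvec_dispatch(A, x):
    if isinstance(A, sparse.BCOO):
        return A @ x          # sparse path
    return jnp.dot(A, x)      # dense path


# The same operation written once. sparsify lets the function accept
# either a dense array or a BCOO matrix without an explicit type check.
@sparse.sparsify
def matvec(A, x):
    return jnp.dot(A, x)


A_dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
A_sparse = sparse.BCOO.fromdense(A_dense)
x = jnp.array([3.0, 4.0])

y_dense = matvec(A_dense, x)    # dense input
y_sparse = matvec(A_sparse, x)  # sparse input, same function
```

With sparsify, the type check moves out of user code and into JAX's tracing machinery, so functions built on top of `matvec` would not need their own dense/sparse branches. Which operations sparsify covers would need to be checked against what MPAX actually uses.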
Related JAX issue: jax-ml/jax#28749