Hi, great work on the package!
The block transition matrices implemented for quasiseparable matrices are a neat optimisation, but I've noticed that some operations haven't been updated to handle them. I give a few minimal examples below; in practice these failures cause a lot of headaches for some multi-wavelength light curve fitting I'm doing, and for GP optimisations I'm implementing that work with quasiseparable matrices directly. I'm not sure in which version these specifically became an issue, but they are certainly present in the latest release.
For this example I'm running:
- tinygp 0.3.1
- jax, jaxlib 0.9.2
```python
import tinygp
import jax.numpy as jnp

N_t = 100
t = jnp.linspace(-10.0, 10.0, N_t)

# k_beat has Block transition matrices
k_beat = tinygp.kernels.quasisep.Cosine(1.0) + tinygp.kernels.quasisep.Cosine(2.0)
banded_term = tinygp.noise.Banded(diag=jnp.ones(N_t), off_diags=jnp.ones((N_t, 1)))

# addition of two QSMs where at least one has Block transition matrices fails
gp1 = tinygp.GaussianProcess(k_beat, t, noise=banded_term)  # breaks

k_prod = k_beat * tinygp.kernels.quasisep.Exp(1.0)

# product of two kernels where at least one has Block transition matrices fails
gp2 = tinygp.GaussianProcess(k_prod, t, diag=jnp.ones(N_t))  # breaks
gp3 = tinygp.GaussianProcess(k_beat * k_beat, t, diag=jnp.ones(N_t))  # breaks
```
## Crash 1: Adding QSMs with Block transition matrices

The first crash gives the error `TypeError: Cannot determine dtype of Block(blocks=(f32[2,2], f32[2,2]))`, which corresponds to:

```python
211 p1, q1, a1 = self
212 p2, q2, a2 = other
213 return StrictLowerTriQSM(
214     p=jnp.concatenate((p1, p2)),
215     q=jnp.concatenate((q1, q2)),
216     a=block_diag(a1, a2),  # <-- fails here
217 )
```
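The failure mode can be illustrated without tinygp: jax's `block_diag` (like most `jnp` functions) only accepts array-likes, so handing it a custom container object raises a `TypeError` of this kind. The `Block` class below is a hypothetical stand-in for tinygp's structured transition matrix, not the real implementation:

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag


class Block:
    """Hypothetical stand-in for tinygp's structured transition matrix."""

    def __init__(self, *blocks):
        self.blocks = blocks


# block_diag on two plain arrays is fine
dense = block_diag(jnp.eye(2), jnp.eye(3))
print(dense.shape)  # a 5x5 block-diagonal matrix

# but a Block instance is not an array-like, so this raises TypeError
try:
    block_diag(jnp.eye(2), Block(jnp.eye(2), jnp.eye(2)))
except TypeError as err:
    print("TypeError:", err)
```

So any code path that routes a `Block` into a plain `jnp` routine, rather than a `Block`-aware one, will fail like this.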
## Crash 2: Product kernel with Block transition matrices

The second crash gives the error `TypeError: dot_general requires contracting dimensions to have the same shape, got (0,) and (4,).`, which corresponds to `Quasisep.to_symm_qsm`:

```python
95 h = jax.vmap(self.observation_model)(X)
96 q = h
97 p = h @ Pinf  # <-- fails here
98 d = jnp.sum(p * q, axis=1)
99 p = jax.vmap(lambda x, y: x @ y)(p, a)
```
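The error itself is just jax reporting a matmul shape mismatch: `h` ends up with an empty trailing dimension while `Pinf` is 4x4 (I assume because the product kernel's observation model and stationary covariance disagree about the state dimension once Blocks are involved). A minimal reproduction of the same message, with made-up shapes:

```python
import jax.numpy as jnp

# shapes chosen to mimic the reported error, not taken from tinygp internals
h = jnp.ones((100, 0))   # observation vectors with an (incorrectly) empty state dim
Pinf = jnp.ones((4, 4))  # 4x4 stationary covariance

try:
    p = h @ Pinf  # contracting dims are (0,) vs (4,) -> TypeError from dot_general
except TypeError as err:
    print("TypeError:", err)
```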
## Crash 3: Product of two Sum kernels

The third crash also happens when building a symmetric QSM:
```python
89 def to_symm_qsm(self, X: JAXArray) -> SymmQSM:
90     """The symmetric quasiseparable representation of this kernel"""
91     Pinf = self.stationary_covariance()  # <-- enters Product.stationary_covariance
92     a = jax.vmap(self.transition_matrix)(
93         jax.tree_util.tree_map(lambda y: jnp.append(y[0], y[:-1]), X), X
94     )
95     h = jax.vmap(self.observation_model)(X)
```
which enters `Product.stationary_covariance`:

```python
273 def stationary_covariance(self) -> JAXArray:
274     return _prod_helper(
275         self.kernel1.stationary_covariance(),
276         self.kernel2.stationary_covariance(),
277     )
```

and from there `_prod_helper`:

```python
639     return a1[i] * a2[j]
640 elif a1.ndim == 2:
641     return a1[i[:, None], i[None, :]] * a2[j[:, None], j[None, :]]  # <-- fails here
642 else:
643     raise NotImplementedError
```
which ultimately hits `Block.__mul__`:

```python
47 @jax.jit
48 def __mul__(self, other: Any) -> "Block":
49     return Block(*(b * other for b in self.blocks))
```

raising `TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'Block'`.
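As far as I can tell, the immediate problem is that under `jit` the left operand is a tracer, the tracer's `__mul__` returns `NotImplemented` for the unrecognised `Block` type, and `Block` defines no `__rmul__` to fall back on. The minimal stand-in below (hypothetical, not tinygp's actual class) reproduces this, and suggests that adding an `__rmul__` would at least cover the commuting scalar case:

```python
import jax
import jax.numpy as jnp


class Block:
    """Hypothetical stand-in: defines __mul__ but, like the real class, no __rmul__."""

    def __init__(self, *blocks):
        self.blocks = blocks

    def __mul__(self, other):
        return Block(*(b * other for b in self.blocks))


@jax.jit
def f(x):
    # tracer.__mul__(Block) returns NotImplemented; no Block.__rmul__ -> TypeError
    return x * Block(jnp.eye(2))


try:
    f(2.0)
except TypeError as err:
    print("TypeError:", err)


class FixedBlock(Block):
    # scalar multiplication commutes, so reusing __mul__ as __rmul__ is safe here
    __rmul__ = Block.__mul__


@jax.jit
def g(x):
    # tracer * FixedBlock now dispatches to FixedBlock.__rmul__;
    # return a plain array so the jitted output is a valid JAX type
    return (x * FixedBlock(jnp.eye(2))).blocks[0]


print(g(2.0))  # 2 * identity
```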
Ideally I would like an option to turn off the formation of block matrices (as suggested in PR #240), but the computational savings they offer are useful, and I'd imagine there's no fundamental obstacle to updating the addition and multiplication rules of kernels to account for Block transition matrices, so that would be the nicer long-term fix.
Thanks again for all the great work!