Hi, I love the minimal nature of tinygp, but I was surprised that it relies on a neural-network library such as Equinox. As far as I can tell, the only reason is that the modules need to be PyTrees, which Equinox handles nicely, but so does JAX itself.
Since Equinox already requires jax>=0.4.38, the change would be trivial (compatibility with jax<0.4.32 would be slightly more annoying).
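To give a sense of why older JAX versions are more annoying: without automatic field inference in `jax.tree_util.register_dataclass`, each dataclass has to be registered by hand. The rough sketch below does that with `jax.tree_util.register_pytree_node`. The helper `register_dataclass_compat` and the `Params` class are hypothetical names for illustration, and the sketch assumes every field is a JAX-visible child (no static fields).

```python
import dataclasses

import jax
import jax.numpy as jnp


def register_dataclass_compat(cls):
    """Hypothetical fallback: turn `cls` into a dataclass and register it as a
    pytree by hand, without relying on register_dataclass field inference."""
    cls = dataclasses.dataclass(cls)
    names = tuple(f.name for f in dataclasses.fields(cls))

    def flatten(obj):
        # Every field value becomes a child; there is no static aux data.
        return tuple(getattr(obj, name) for name in names), None

    def unflatten(aux_data, children):
        # Bypass __init__ so JAX can rebuild the object from arbitrary leaves.
        obj = object.__new__(cls)
        for name, value in zip(names, children):
            object.__setattr__(obj, name, value)
        return obj

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_dataclass_compat
class Params:
    amplitude: jax.Array
    scale: jax.Array


params = Params(amplitude=jnp.array(2.0), scale=jnp.array(3.0))

# tree_map rebuilds a Params instance with every field doubled.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
# jit accepts the instance as an ordinary pytree argument.
product = jax.jit(lambda p: p.amplitude * p.scale)(params)
print(doubled, product)
```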
It would go from:

```python
import equinox as eqx

class GaussianProcess(eqx.Module):
    ...
```

to pure JAX/Python:
```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import jax

@jax.tree_util.register_dataclass
@dataclass
class GaussianProcess(ABC):
    ...
```

but with JAX as the only dependency.
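The only wrinkle is subclassing: JAX's pytree registration applies to the exact class being registered and is not inherited, so every user-defined subclass would need to be decorated as well. A simple helper decorator like

```python
def tinyclass(cls):
    return jax.tree_util.register_dataclass(dataclass(cls))
```

would reduce that to

```python
@tinygp.tinyclass
class SpectralMixture(tinygp.kernels.Kernel):
    ...
```

which is, after all, the JAX approach.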
If this is something you would be interested in, I wouldn't mind opening a PR to keep tinygp... tiny :)
This was just a random observation.