
Replace Equinox PyTrees with normal dataclasses #248

@alonfnt


Hi, I love the minimal nature of tinygp, but I was surprised that it relies on a neural-network library such as Equinox. As far as I can tell, the only reason is that the modules need to be PyTrees, which Equinox handles nicely, but so does JAX itself.

Since Equinox already requires jax>=0.4.38, tinygp users necessarily have a recent JAX anyway, so the change would be trivial (keeping compatibility with jax<0.4.32 would be slightly more annoying).
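For reference, `jax.tree_util.register_dataclass` can also be given the field lists explicitly through its `data_fields` and `meta_fields` arguments instead of inferring them. This is presumably what supporting older JAX releases would involve; that is an assumption, so the exact release where inference became available should be checked in the JAX changelog. `ExpSquared` is a hypothetical class used only for illustration.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical kernel, used only to illustrate the explicit registration form.
@dataclass
class ExpSquared:
    scale: jax.Array           # data field: becomes a traced PyTree leaf
    name: str = "exp_squared"  # meta field: stored in the treedef, not traced


# Spell out which fields are leaves and which are static metadata,
# instead of relying on JAX to infer them.
jax.tree_util.register_dataclass(
    ExpSquared, data_fields=["scale"], meta_fields=["name"]
)

k = ExpSquared(scale=jnp.array(1.5))
leaves, treedef = jax.tree_util.tree_flatten(k)
```

The downside is that both lists have to be kept in sync by hand whenever a field is added or removed.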

It would go from:

```python
import equinox as eqx

class GaussianProcess(eqx.Module):
    ...
```

to pure JAX/Python:

```python
import jax
from abc import ABC, abstractmethod
from dataclasses import dataclass

@jax.tree_util.register_dataclass
@dataclass
class GaussianProcess(ABC):
    ...
```

while leaving JAX as the only dependency.

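The only complication would be user subclasses (e.g. custom kernels): unlike subclasses of eqx.Module, they are not automatically turned into dataclasses or registered as PyTrees. A simple helper decorator like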

```python
def tinyclass(cls):
    return jax.tree_util.register_dataclass(dataclass(cls))
```

would mean users simply write

```python
@tinygp.tinyclass
class SpectralMixture(tinygp.kernels.Kernel):
    ...
```

which is, after all, the idiomatic JAX approach.

If this is something you would be interested in, I wouldn't mind opening a PR to keep tinygp... tiny :)

This was just a random observation.
