-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
53 lines (41 loc) · 1.96 KB
/
api.py
File metadata and controls
53 lines (41 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import jax
from . import core, io
from .config import Config
from .structure import init_params as _init_params, init_state as _init_state
from .pretrain import pretrain
from dataclasses import dataclass
@dataclass(frozen=True)
class NCA:
    """Bind a ``Config`` to a set of jit-compiled callables that close over it.

    After construction the instance exposes these attributes, each wrapped
    with ``jax.jit``:

    * ``step(st, pr, key)`` -> ``core.step(st, pr, key, config)``
    * ``rollout(st, pr, key, *, K=None)`` -> ``core.rollout(...)``
    * ``inform(st, *, value, mode)`` -> ``io.inform(...)``
    * ``extract(st)`` -> ``io.extract(st, config)``
    * ``process(st, pr, key, x, *, K=None, mode="set")`` -> inform, rollout,
      extract; returns ``(extracted_output, rolled_state)``
    * ``pretrain(pr, key, **kwargs)`` -> ``pretrain(key=..., params=..., config=...)``
    * ``get_overflow_penalty(st, *, bound=5.0)`` -> ``core.get_overflow_penalty(...)``

    NOTE(review): ``st`` / ``pr`` are presumably the state and parameter
    pytrees produced by ``init_state`` / ``init_params`` -- confirm against
    ``core`` and ``structure``.
    """

    config: Config

    def __post_init__(self):
        """Create the config-bound closures and attach their jitted forms."""
        cfg = self.config

        def _pick_k(K):
            # ``None`` selects the configured default; any other value is
            # coerced with ``int``.
            if K is None:
                return cfg.k_default
            return int(K)

        def _step(st, pr, key):
            return core.step(st, pr, key, cfg)

        def _rollout(st, pr, key, *, K=None):
            return core.rollout(st, pr, key, _pick_k(K), cfg)

        def _inform(st, *, value, mode):
            return io.inform(st, cfg, value=value, mode=mode)

        def _extract(st):
            return io.extract(st, cfg)

        def _process(st, pr, key, x, *, K=None, mode="set"):
            # Resolve the step count before touching the state, then hand the
            # resolved value to ``_rollout`` (which coerces it once more).
            n_steps = _pick_k(K)
            informed = io.inform(st, cfg, value=x, mode=mode)
            # ``_rollout`` yields a pair; only the first element is used here.
            rolled, _ = _rollout(informed, pr, key, K=n_steps)
            return io.extract(rolled, cfg), rolled

        def _get_overflow_penalty(st, *, bound: float = 5.0):
            return core.get_overflow_penalty(st, cfg, bound=bound)

        def _pretrain(pr, key, **kwargs):
            return pretrain(key=key, params=pr, config=cfg, **kwargs)

        # (attribute name, python callable, keyword names treated as static).
        # ``None`` leaves ``jax.jit`` at its default of no static names.
        registry = (
            ("step", _step, None),
            ("rollout", _rollout, ("K",)),
            ("inform", _inform, ("mode",)),
            ("extract", _extract, None),
            ("process", _process, ("K", "mode")),
            ("pretrain", _pretrain, ("steps", "batch_size", "K")),
            ("get_overflow_penalty", _get_overflow_penalty, ("bound",)),
        )
        for attr_name, fn, static_names in registry:
            # The dataclass is frozen, so plain attribute assignment would
            # raise; go through ``object.__setattr__`` instead.
            object.__setattr__(
                self, attr_name, jax.jit(fn, static_argnames=static_names)
            )

    def init_params(self, key):
        """Return ``structure.init_params(key, self.config)``."""
        return _init_params(key, self.config)

    def init_state(self, key):
        """Return ``structure.init_state(key, self.config)``."""
        return _init_state(key, self.config)