-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstructure.py
More file actions
116 lines (89 loc) · 3.58 KB
/
structure.py
File metadata and controls
116 lines (89 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jax.nn.initializers import glorot_uniform
from .config import Config
from .utils import _circle_layout, _default_output_nodes
# -- state --
@jax.tree_util.register_pytree_node_class
@dataclass(slots=True, frozen=True)
class State:
grid: jnp.ndarray # shape = (C, grid_size, grid_size)
def tree_flatten(self):
return ((self.grid,), None)
@classmethod
def tree_unflatten(cls, aux, children):
(grid,) = children
return cls(grid=grid)
@property
def shape(self) -> tuple[int, int, int]:
return self.grid.shape
@jax.tree_util.register_pytree_node_class
@dataclass(slots=True, frozen=True)
class Params:
'''
MLP is two dense layers applied channel-wise
'''
# learned conv frontend
conv_w: jnp.ndarray
conv_b: jnp.ndarray
w1: jnp.ndarray # (in_dim, hidden)
b1: jnp.ndarray # (hidden,)
w2: jnp.ndarray # (hidden, out_dim)
b2: jnp.ndarray # (out_dim,)
gain: jnp.ndarray # (config.C,)
def tree_flatten(self):
return ((self.conv_w, self.conv_b, self.w1, self.b1, self.w2, self.b2, self.gain), None)
@classmethod
def tree_unflatten(cls, aux, children):
conv_w, conv_b, w1, b1, w2, b2, gain = children
return cls(conv_w=conv_w, conv_b=conv_b, w1=w1, b1=b1, w2=w2, b2=b2, gain=gain)
@property
def sizes(self):
return (self.conv_w.shape, self.conv_b.shape, self.w1.shape, self.b1.shape, self.w2.shape, self.b2.shape, self.gain.shape)
# -- initializers --
def init_state(key: jax.Array, config: Config) -> State:
    """Build the initial grid state.

    Zeros everywhere, then: input-flag channel marked 1 at the circle-layout
    input nodes, hidden channels seeded to 1 at those same cells (per paper,
    they remain fixed under the core freeze), and the output-flag channel
    marked 1 at the near-center output nodes.

    Note: ``key`` is accepted for interface symmetry but the initial state
    is deterministic.
    """
    channels, side = config.C, config.grid_size
    grid = jnp.zeros((channels, side, side), dtype=config.dtype)

    # Input nodes sit on a circle; mark them in the input-flag channel.
    inputs = _circle_layout(side, config.num_input_nodes)
    ix, iy = inputs[:, 0], inputs[:, 1]
    grid = grid.at[config.idx_in_flag, iy, ix].set(1.0)

    # Hidden channels (1 .. H) start at 1 at the input cells; a single
    # vectorized scatter replaces the per-channel loop.
    n_hidden = config.hidden_channels
    if n_hidden > 0:
        grid = grid.at[1:1 + n_hidden, iy, ix].set(1.0)

    # Output nodes cluster near the center; mark the output-flag channel.
    outputs = _default_output_nodes(side, config.num_output_nodes)
    grid = grid.at[config.idx_out_flag, outputs[:, 1], outputs[:, 0]].set(1.0)
    return State(grid=grid)
def init_params(key: jax.Array, config: Config) -> Params:
    """Initialize all learnable parameters.

    Dense layers get Glorot-uniform weights and zero biases. The optional
    3x3 conv frontend gets He-style scaled normal weights (fan-in = 9 * C);
    when perception is not 'learned3x3', empty (0,) placeholders keep the
    pytree structure fixed. Per-channel gains start at 0.5 except the
    input/output flag channels, which are pinned to 0.
    """
    in_dim = config.input_feats_per_cell
    hidden = config.hidden
    out_dim = config.C
    # k2 is currently unused; the 4-way split is kept so RNG streams for
    # k1/k3/k4 stay identical to prior runs.
    k1, k2, k3, k4 = jax.random.split(key, 4)

    # Optional learned 3x3 conv filters.
    if config.perception == 'learned3x3':
        n_filters = config.conv_features
        scale = jnp.sqrt(2.0 / (9 * config.C))
        conv_w = scale * jax.random.normal(
            k1, (n_filters, config.C, 3, 3), dtype=config.dtype)
        conv_b = jnp.zeros((n_filters,), dtype=config.dtype)
    else:
        conv_w = jnp.zeros((0,), dtype=config.dtype)
        conv_b = jnp.zeros((0,), dtype=config.dtype)

    # Two dense layers, Glorot-uniform weights / zero biases.
    glorot = glorot_uniform()
    w1 = glorot(k3, (in_dim, hidden), dtype=config.dtype)
    w2 = glorot(k4, (hidden, out_dim), dtype=config.dtype)
    b1 = jnp.zeros((hidden,), dtype=config.dtype)
    b2 = jnp.zeros((out_dim,), dtype=config.dtype)

    # Per-channel gains: 0.5 default, frozen at 0 for the flag channels.
    gain = jnp.full((out_dim,), 0.5, dtype=config.dtype)
    gain = gain.at[config.idx_in_flag].set(0.0)
    gain = gain.at[config.idx_out_flag].set(0.0)

    return Params(conv_w=conv_w, conv_b=conv_b,
                  w1=w1, b1=b1, w2=w2, b2=b2, gain=gain)
# -- utils --
def num_params(p: Params) -> int:
    """Return the total scalar count across every parameter array in ``p``."""
    total = 0
    for arr in (p.conv_w, p.conv_b, p.w1, p.b1, p.w2, p.b2, p.gain):
        total += int(jnp.size(arr))
    return total