forked from TuGraph-family/TuGraph-AntGraphLearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsubgraph.py
More file actions
251 lines (195 loc) · 8.47 KB
/
subgraph.py
File metadata and controls
251 lines (195 loc) · 8.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#!/usr/bin/python
# coding: utf-8
from typing import List
import numpy as np
from pyagl.pyagl import (
AGLDType,
DenseFeatureSpec,
SparseKVSpec,
SparseKSpec,
NodeSpec,
EdgeSpec,
SubGraph,
NDArray,
)
class PySubGraph:
    """Python wrapper around the C++ ``SubGraph``.

    Calls the C++ subgraph to parse pbs, and exposes the edge indexes, node
    features and edge features of the parsed subgraph as numpy arrays.

    Example (placeholder values, not an executable doctest)::

        node_spec = NodeSpec("n", AGLDType.STR)
        edge_spec = EdgeSpec("e", node_spec, node_spec, AGLDType.STR)
        edge_spec.AddDenseSpec("time", DenseFeatureSpec("time", 1, AGLDType.INT64))
        sg = PySubGraph([node_spec], [edge_spec])
        sg.from_pb(["xxx"])
        print(sg.get_edge_index())
        # {"e": (xxx, xxx, xxx, xxx, xxx)}
        # i.e. {edge_name: (row_ptr, col_indices, edge_indices, row_num, col_num)}
    """

    def __init__(self, node_specs: List[NodeSpec], edge_specs: List[EdgeSpec]):
        """Create the underlying C++ subgraph and register all specs on it.

        Args:
            node_specs(List[NodeSpec]): list of node specs, each type of node has a node spec
            edge_specs(List[EdgeSpec]): list of edge specs, each type of edge has an edge spec
        """
        self.sg = SubGraph()
        self.n_specs = node_specs
        self.e_specs = edge_specs
        for n_spec in node_specs:
            self.sg.AddNodeSpec(n_spec.GetNodeName(), n_spec)
        for e_spec in edge_specs:
            self.sg.AddEdgeSpec(e_spec.GetEdgeName(), e_spec)

    @staticmethod
    def _to_1d(nd_array) -> np.ndarray:
        """Copy ``nd_array`` into a numpy array and drop its trailing axis.

        ``np.squeeze(..., axis=-1)`` raises ``ValueError`` unless the last axis
        has length 1, so the wrapped array is expected to be column shaped
        (presumably ``(n, 1)`` as exported by the C++ side -- TODO confirm).
        """
        return np.squeeze(np.array(nd_array), axis=-1)

    def from_pb(
        self,
        graph_features: List[str],
        is_merge: bool = False,
        uncompress: bool = False,
    ) -> None:
        """Pass a list of pb strings to c++, to parse the pbs into a subgraph.

        Args:
            graph_features(List[str]): a list of pb strings, each represents a subgraph wrt. a certain
                node or a set of nodes.
            is_merge(bool): Default is False. Whether to merge those subgraphs into one subgraph.
                Now only disjoint merge (is_merge=False) is supported.
            uncompress(bool): whether the pb strings should be uncompressed with the default
                compression algorithm (i.e., gzip)
        """
        # parse pb graph features
        self.sg.CreateFromPB(graph_features, is_merge, uncompress)

    def from_pb_bytes(
        self,
        graph_features: List[bytearray],
        is_merge: bool = False,
        uncompress: bool = False,
    ) -> None:
        """Efficiently pass a list of pb bytearrays to c++, to parse the pbs into a subgraph.

        Args:
            graph_features(List[bytearray]): a list of pb strings (encoded with utf-8 to efficiently
                pass to c++), each represents a subgraph wrt. a certain node or a set of nodes.
            is_merge(bool): Default is False. Whether to merge those subgraphs into one subgraph.
                Now only disjoint merge (is_merge=False) is supported.
            uncompress(bool): whether the pb strings should be uncompressed with the default
                compression algorithm (i.e., gzip)
        """
        self.sg.CreateFromPBBytesArray(graph_features, is_merge, uncompress)

    def get_edge_index(self) -> dict:
        """Return the CSR edge index of every registered edge type.

        Returns: {edge_name: (row_ptr, col_indices, edge_indices, row_num, col_num)}. A dict of edge
            indexes wrt. their types. ``edge_indices`` is ``0 .. len(col_indices) - 1`` with the same
            dtype as ``row_ptr``.
        """
        result = {}
        for e_spec in self.e_specs:
            name = e_spec.GetEdgeName()
            csr = self.sg.GetEdgeIndexCSR(name)
            ind_offset = self._to_1d(csr.GetIndPtr())
            n2_indices = self._to_1d(csr.GetIndices())
            # Edges are numbered by their position in the CSR column array.
            edge_index = np.arange(len(n2_indices), dtype=ind_offset.dtype)
            result[name] = (ind_offset, n2_indices, edge_index, csr.row_num, csr.col_num)
        return result

    def get_ego_edge_index(self, hops: int) -> list:
        """Get the ego-subgraph edge indexes wrt. the number of hops.

        With the iteration of many GNNs, the receptive-field will decrease. For example, the first
        iteration of a 3-layer GNN usually should be conducted on the 3-hop neighborhood wrt. target
        nodes. The second iteration only needs the 2-hop neighborhood of target nodes.
        get_ego_edge_index will return a set of edge indexes wrt. hops.

        Args:
            hops: how many hops we need.

        Returns: List of edge indexes, in the reverse of the order returned by the C++ call:
            [{edge_name: edge_index_k_hop}, {edge_name: edge_index_{k-1}_hop} ...], where every
            edge index is a tuple (n1_indices, n2_indices, edge_indices).
        """
        res = self.sg.GetEgoEdgeIndex(hops)
        res_final = []
        for hop_edges in res:
            res_final.append(
                {
                    e_name: (
                        self._to_1d(e_coo.GetN1Indices()),
                        self._to_1d(e_coo.GetN2Indices()),
                        self._to_1d(e_coo.GetEdgeIndex()),
                    )
                    for e_name, e_coo in hop_edges.items()
                }
            )
        # One reversal at the end replaces a front-insert per hop.
        res_final.reverse()
        return res_final

    def get_node_dense_feature(self, node_name: str, f_name: str) -> np.ndarray:
        """Get a dense feature of one node type.

        Args:
            node_name(str): node (type) name
            f_name(str): dense feature name

        Returns: 2-dim np.array(), node_num * dense_feature_dim
        """
        res = self.sg.GetNodeDenseFeatureArray(node_name, f_name)
        return np.array(res.GetFeatureArray())

    def get_node_sparse_kv_feature(self, node_name: str, f_name: str) -> tuple:
        """Get a sparse key-value feature of one node type.

        Args:
            node_name(str): node (type) name
            f_name(str): sparse kv feature name

        Returns: (indices_offset, keys, values)
        """
        res = self.sg.GetNodeSparseKVArray(node_name, f_name)
        f_array = res.GetFeatureArray()
        ind_offset = self._to_1d(f_array.GetIndOffset())
        keys = self._to_1d(f_array.GetKeys())
        values = self._to_1d(f_array.GetVals())
        return ind_offset, keys, values

    def get_node_sparse_k_feature(self, node_name: str, f_name: str) -> tuple:
        """Get a sparse key feature of one node type.

        Args:
            node_name(str): node (type) name
            f_name(str): sparse key feature name

        Returns: (indices_offset, keys)
        """
        # TODO(zdl): add check and raise error if not exists
        res = self.sg.GetNodeSparseKArray(node_name, f_name)
        f_array = res.GetFeatureArray()
        ind_offset = self._to_1d(f_array.GetIndOffset())
        keys = self._to_1d(f_array.GetKeys())
        return ind_offset, keys

    def get_edge_dense_feature(self, edge_name: str, f_name: str) -> np.ndarray:
        """Get a dense feature of one edge type.

        Args:
            edge_name(str): edge (type) name
            f_name(str): edge dense feature name

        Returns: 2-d np array. edge_num * dense_feature_dim
        """
        res = self.sg.GetEdgeDenseFeatureArray(edge_name, f_name)
        return np.array(res.GetFeatureArray())

    def get_edge_sparse_kv_feature(self, edge_name: str, f_name: str) -> tuple:
        """Get a sparse key-value feature of one edge type.

        Args:
            edge_name(str): edge (type) name
            f_name(str): edge sparse kv feature name

        Returns: (indices_offset, keys, values)
        """
        res = self.sg.GetEdgeSparseKVArray(edge_name, f_name)
        f_array = res.GetFeatureArray()
        ind_offset = self._to_1d(f_array.GetIndOffset())
        keys = self._to_1d(f_array.GetKeys())
        values = self._to_1d(f_array.GetVals())
        return ind_offset, keys, values

    def get_edge_sparse_k_feature(self, edge_name: str, f_name: str) -> tuple:
        """Get a sparse key feature of one edge type.

        Args:
            edge_name(str): edge (type) name
            f_name(str): edge sparse key feature name

        Returns: (indices_offset, keys)
        """
        res = self.sg.GetEdgeSparseKArray(edge_name, f_name)
        f_array = res.GetFeatureArray()
        ind_offset = self._to_1d(f_array.GetIndOffset())
        keys = self._to_1d(f_array.GetKeys())
        return ind_offset, keys

    def get_node_num_per_sample(self):
        """Node num per sample.

        Returns: {n_name1: [XXX], n_name2: [XXX]}, node_num per sample (pb);
        """
        return self.sg.GetNodeNumPerSample()

    def get_edge_num_per_sample(self):
        """Edge num per sample.

        Returns: {e_name1: [XXX], e_name2: [XXX]}, edge_num per sample (pb);
        """
        return self.sg.GetEdgeNumPerSample()

    def get_root_index(self):
        """Root ids of each sample.

        Returns: root ids of each sample. {name: [[]]}. name -> sample_index -> root_ids
        """
        return self.sg.GetRootIds()