Conversation
ge/models/line.py
Outdated
| return self._embeddings | ||
|
|
||
| def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1): | ||
| def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1,workers=tf.data.experimental.AUTOTUNE,use_multiprocessing=True): |
There was a problem hiding this comment.
tf.data.experimental.AUTOTUNE可以让程序自动的选择最优的线程并行个数
There was a problem hiding this comment.
当然用户也可以自己选择workers的数量,这里就是做为默认的设定
|
|
||
|
|
||
|
|
||
| """ |
There was a problem hiding this comment.
修改的时候直接复制进来,给替换掉了。。。
ge/models/node2vec.py
Outdated
| class Node2Vec: | ||
|
|
||
| def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs): | ||
| def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0,threads=1): |
There was a problem hiding this comment.
def init(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1, use_rejection_sampling=0):部分的参数移动到train的部分了,use_rejection_sampling 这个木有实现
There was a problem hiding this comment.
use_rejection_sampling 如果需要增加这个的numba实现我可以写一下
ge/models/node2vec.py
Outdated
| self._embeddings = {} | ||
| for word in self.graph.nodes(): | ||
| self._embeddings[word] = self.w2v_model.wv[word] | ||
| for word in self.node_dict.keys(): |
There was a problem hiding this comment.
为什么用self.node_dict替换self.graph?
There was a problem hiding this comment.
csrgraph是以scipy形式存储图的,所以节点的名字变成了0,1,2,3.。。。这样的形式,node_dict是networkx和csrgraph之间的节点名字的对应关系,比如原来节点叫“XXX”可能对应的是新的节点名是1这样
shenweichen
left a comment
There was a problem hiding this comment.
ge/models/deepwalk.py 这个文件被你删除了。。
另外看下其他文件的一些修改我有些疑问,麻烦看下
wangbingnan136
left a comment
There was a problem hiding this comment.
因为node2vec的接口已经实现了deepwalk了,所以就把原来deep walk去掉了,当p和q都为1的时候,csrgraph内部会自动选择deepwalk对应的优化游走策略
|
|
||
|
|
||
|
|
||
| """ |
There was a problem hiding this comment.
修改的时候直接复制进来,给替换掉了。。。
ge/models/node2vec.py
Outdated
| class Node2Vec: | ||
|
|
||
| def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs): | ||
| def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0,threads=1): |
There was a problem hiding this comment.
def init(self, graph, walk_length, num_walks, p=1.0, q=1.0, workers=1, use_rejection_sampling=0):部分的参数移动到train的部分了,use_rejection_sampling 这个木有实现
ge/models/node2vec.py
Outdated
| class Node2Vec: | ||
|
|
||
| def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs): | ||
| def __init__(self, graph, walk_length, num_walks, p=1.0, q=1.0,threads=1): |
There was a problem hiding this comment.
use_rejection_sampling 如果需要增加这个的numba实现我可以写一下
ge/models/node2vec.py
Outdated
| self._embeddings = {} | ||
| for word in self.graph.nodes(): | ||
| self._embeddings[word] = self.w2v_model.wv[word] | ||
| for word in self.node_dict.keys(): |
There was a problem hiding this comment.
csrgraph是以scipy形式存储图的,所以节点的名字变成了0,1,2,3.。。。这样的形式,node_dict是networkx和csrgraph之间的节点名字的对应关系,比如原来节点叫“XXX”可能对应的是新的节点名是1这样
ge/models/line.py
Outdated
| return self._embeddings | ||
|
|
||
| def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1): | ||
| def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1,workers=tf.data.experimental.AUTOTUNE,use_multiprocessing=True): |
There was a problem hiding this comment.
当然用户也可以自己选择workers的数量,这里就是做为默认的设定
|
deepwalk去掉的话会让用户有困惑的。建议保留deepwalk的接口,底层可以调用node2vec |
.