python - 处理类中的 tensorflow session

标签 python machine-learning tensorflow deep-learning

我正在使用 tensorflow 来预测神经网络的输出。我有一个类,我在其中描述了神经网络,并且有一个主文件,其中进行预测,并根据结果更新权重。然而,预测似乎真的很慢。我的代码如下所示:

class NNPredictor():
    def __init__(self):
        self.input = tf.placeholder(...)
        ...
        self.output = (...) #Neural network output
    def predict_output(self, sess, input):
        return sess.run(tf.squeeze(self.output), feed_dict = {self.input: input})

主文件如下所示:

sess = tf.Session()
predictor = NNPredictor()

input = #some initial value 
for i in range(iter):
    output = predictor.predict_output(sess, input)
    input = #some function of output

但是,如果我在类中使用以下函数定义:

    def predict_output(self):
        return self.output

主文件如下:

sess = tf.Session()
predictor = NNPredictor()

input = #some initial value 
output_op = predictor.predict_value()
for i in range(iter):
    output = np.squeeze(sess.run(output_op, feed_dict = {predictor.input: input}))
    input = #some function of output

代码的运行速度几乎快了 20-30 倍。我似乎不明白这里的事情是如何运作的,我想知道最佳实践是什么。

最佳答案

这与 Python 屏蔽的底层内存访问有关。下面是一些示例代码来说明这个想法:

import time

runs = 10000000

class A:
    def __init__(self):
    self.val = 1

    def get_val(self):
    return self.val

# Using method to then call object attribute
obj = A()
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.get_val()
end = time.time()
print end - start

# Using object attribute directly
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.val
end = time.time()
print end - start

# Assign to local_var first
start = time.time()
total = 0
local_var = obj.get_val()
for i in xrange(runs):
    total += local_var
end = time.time()
print end - start

在我的计算机上,它按以下时间运行:

1.49576115608
0.656110048294
0.551875114441

具体到您的情况,您在第一种情况下调用对象方法,但在第二种情况下不这样做。如果您以这种方式多次调用代码,则会出现显着的性能差异。

关于python - 处理类中的 tensorflow session ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44896433/

相关文章:

python - 嵌套循环的字典理解没有按计划工作

python - 更改tensorflow object_detection教程中的结果数量

python - 将颜色条添加到集群热图

使用 ctypes 指向压缩结构中的 uint_64 的 Python 指针地址?

python - 多对多字段和 request.user 不起作用

python - Google ML Engine preprocessor_pb2 ImportError

python - 使用 Mahalanobis 距离进行多元异常值去除

java - 使用斯坦福解析器解析凌乱的文本

tensorflow - 使用 Tensorboard 在一张图中绘制多个图

Tensorflow 2.1.0 - 函数构建代码之外的操作正在传递一个 "Graph"张量