Below is the simple MNIST tutorial from the TensorFlow website (i.e. a single-layer softmax), which I tried to extend with a multi-threaded training step:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import threading
# Training loop executed in each thread
def training_func():
    while True:
        batch = mnist.train.next_batch(100)
        global_step_val, _ = sess.run([global_step, train_step], feed_dict={x: batch[0], y_: batch[1]})
        print("global step: %d" % global_step_val)
        if global_step_val >= 4000:
            break
# create session and graph
sess = tf.Session()
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
global_step = tf.Variable(0, name="global_step")
y = tf.matmul(x,W) + b
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
inc = global_step.assign_add(1)
with tf.control_dependencies([inc]):
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# initialize graph and create mnist loader
sess.run(tf.global_variables_initializer())
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# create workers and execute threads
workers = []
for _ in range(8):
    t = threading.Thread(target=training_func)
    t.start()
    workers.append(t)
for t in workers:
    t.join()
# evaluate accuracy of the model
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels},
                    session=sess))
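I must be missing something, because running 8 threads as shown above produces inconsistent results (accuracy of about 0.1), whereas with 1 thread I get the expected accuracy (about 0.92). Does anyone know what my mistake is? Thanks!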
Best answer
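Note that, unfortunately, because of the GIL, threading in Python does not create real parallelism. So what happens here is that you have multiple threads all running on the same CPU, and in practice they run sequentially. I would therefore suggest using a Coordinator in TensorFlow. More information about the Coordinator can be found here: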
https://www.tensorflow.org/programmers_guide/threading_and_queues
https://www.tensorflow.org/programmers_guide/reading_data
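To make the coordinator idea concrete, here is a minimal sketch of the stop-signal pattern using only the Python standard library. It is not TensorFlow's implementation: `StopSignal`, `run_workers`, and the `state` dictionary are hypothetical names made up for illustration, and the training step itself is only a placeholder comment. The method names `should_stop()`, `request_stop()`, and `join(threads)` are chosen to mirror the ones that `tf.train.Coordinator` exposes.

```python
import threading


class StopSignal:
    """Hypothetical stdlib stand-in for the stop-signal part of a coordinator."""

    def __init__(self):
        self._stop = threading.Event()

    def should_stop(self):
        # True once any thread has asked all workers to stop.
        return self._stop.is_set()

    def request_stop(self):
        self._stop.set()

    def join(self, threads):
        # Wait for every worker thread to finish.
        for t in threads:
            t.join()


def run_workers(num_workers=8, max_steps=4000):
    coord = StopSignal()
    lock = threading.Lock()
    state = {"global_step": 0}  # shared counter, only modified under the lock

    def worker():
        while not coord.should_stop():
            with lock:
                if state["global_step"] >= max_steps:
                    coord.request_stop()
                    break
                state["global_step"] += 1
            # A real training step (for example a sess.run call) would go
            # here, outside the lock; this sketch leaves it out.

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    coord.join(threads)
    return state["global_step"]


if __name__ == "__main__":
    print("final global step:", run_workers())
```

Because each worker claims a step under the lock before doing any work, the counter ends at exactly `max_steps` no matter how many threads run, and every thread leaves its loop once one of them requests a stop.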
Finally, I suggest you write:
with tf.device('/cpu:0'):
    pass  # your code for the first thread goes here
Then use another CPU for the next thread, and so on. I hope this answer is useful to you!
Regarding python - TensorFlow and threading, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/41697662/