python - 用于图像特征提取的 tensorflow 多重处理

标签 python multithreading tensorflow

我有一些基本函数,可以接收图像的 URL 并通过 VGG-16 CNN 对其进行转换:

def convert_url(_id, url):   
  im = get_image(url)
  return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im}))

我有一大组 URL(约 60,000 个),我想在其中执行此功能。每次迭代都需要超过一秒的时间,这太慢了。我想通过并行使用多个进程来加快速度。无需担心共享状态,因此多线程的常见陷阱不是问题。

但是,我不太确定如何真正让 tensorflow 与多处理包一起使用。我知道您无法将 tensorflow session 传递给 Pool 变量。因此,我尝试初始化 session 的多个实例:

def init():
  global sess;
  sess = tf.Session()

但是当我实际启动该进程时,它只是无限期地挂起:

with Pool(processes=3,initializer=init) as pool:
  results = pool.starmap(convert_url, list(id_img_dict.items())[0:5])

请注意, tensorflow 图是全局定义的。我认为这是正确的方法,但我不确定:

input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)

arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
  _, end_points = vgg_16(scaled_input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)

有人可以帮我让它工作吗?非常感谢。

最佳答案

忘记 python 的普通多线程工具并使用 tensorflow.contrib.data.Dataset 。尝试如下操作。

urls = ['img1.jpg', 'img2.jpg', ...]
batch_size = 16
n_batches = len(urls) // batch_size  # do something more elegant for remainder


def load_img(url):
    image = tf.read_file(url, name='image_data')
    image = tf.image.decode_jpeg(image, channels=3, name='image')
    return image


def preprocess(img_tensor):
    img_tensor = (tf.cast(img_tensor, tf.float32) / 255 - 0.5)*2
    img_tensor.set_shape((256, 256, 3))  # whatever shape
    return img_tensor


dataset = tf.contrib.data.Dataset.from_tensor_slices(urls)
dataset = dataset.map(load_img).map(preprocess)

preprocessed_images = dataset.batch(
    batch_size).make_one_shot_iterator().get_next()


arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
    _, end_points = vgg_16(preprocessed_images, is_training=False)
    output = end_points['vgg_16/fc7']


results = []

with tf.Session() as sess:
    tf.train.Saver().restore(sess, checkpoint_file)
    for i in range(n_batches):
        batch_results = sess.run(output)
        results.extend(batch_results)
        print('Done batch %d / %d' % (i+1, n_batches))

关于python - 用于图像特征提取的 tensorflow 多重处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45909738/

相关文章:

python - 函数中的矩阵 View 没有副作用

Python如何在使用one-hot-encode/pd.get_dummies后反转实际值

c++ - 线程间并发无锁?

performance - 在 Jmeter 中减速的最佳方法是什么?

python - 在全局上下文中使用一个 GradientTape

python - 根据灰度动态范围将图像分类为褪色或未褪色?

python - 如何在 Pandas 中连接可变数量的列

.net - 是否保证多个发送线程的 WCF TCP 消息顺序正确?

tensorflow - Tensorflow 中的平衡准确度分数

tensorflow - 新版本的TensorFlow中的tf.nn.rnn等价于什么?