python - 在 TensorFlow 上使用 Inception 时出错(所有图片的输出相同)

标签 python tensorflow deep-learning

我正在尝试在 cifar-10 数据集上训练网络,但我不想使用图片,而是想使用 Inceptions 的最后一层之前的特征。

所以我写了一些peace pf代码来传递Inception中的所有图片并获取特征,如下:

def run_inference_on_images(images):
 #Creates graph from saved GraphDef.
 create_graph()

 features_vec = np.ndarray(shape=(len(images),2048),dtype=np.float32)

 with tf.Session() as sess:
   # Some useful tensors:
   # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
   #   float description of the image.
   # 'DecodeJpeg:0': A numpy array of the image
   # Runs the softmax tensor by feeding the image data as input to the graph.
   length = len(images)
   for i in range(length):
       print ('inferencing image number',i,'out of', length)
       features_tensor = sess.graph.get_tensor_by_name('pool_3:0')
       features = sess.run(features_tensor,
                        {'DecodeJpeg:0': images[i]})
       features_vec[i] = np.squeeze(features)
 return features_vec

“images”是 CIFAR-10 数据集。这是一个形状为 (50000,32,32,3) 的 numpy 数组

我面临的问题是,即使我向“sess.run”部分提供不同的图片,“features”输出也始终相同。 我错过了什么吗?

最佳答案

我能够解决这个问题。看来 Inception 并不像我想象的那样使用 numPy 数组,因此我将数组转换为 JPEG 图片,然后才将其馈送到网络。

下面是有效的代码(其余相同):

def run_inference_on_images(images):
  # Creates graph from saved GraphDef.
  create_graph()

  features_vec = np.ndarray(shape=(len(images),2048),dtype=np.float32)
  with tf.Session() as sess:
    features_tensor = sess.graph.get_tensor_by_name('pool_3:0')
    length = len(images)
    for i in range(length):
        im = Image.fromarray(images[i],'RGB')
        im.save("tmp.jpeg")
        data = tf.gfile.FastGFile("tmp.jpeg", 'rb').read()
        print ('inferencing image number',i,'out of', length)
        features = sess.run(features_tensor,
                        {'DecodeJpeg/contents:0': data})
        features_vec[i] = np.squeeze(features)       
   return features_vec

关于python - 在 TensorFlow 上使用 Inception 时出错(所有图片的输出相同),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37723808/

相关文章:

Python - 带有 input() 的 If 语句在 def() 函数内部不起作用

python - 检查不存在的可选表

python - 了解 PyTorch 中的反向传播

tensorflow - 在其他平台(如 Linux)上构建 tensorflow lite

python - 由于判别器输出为负,Tensorflow GAN 判别器损失 NaN

python-3.x - 我们可以提取未训练过的类的 VGG16/19 特征吗

machine-learning - tensorflow 。从 BasicRNNCell 切换到 LSTMCell

python - 错误: Message: element click intercepted:

python - Sendgrid 使用 API key 进行身份验证

Tensorflow 不是确定性的,它应该在哪里