我正在使用 Tensorflow 进行图像分类。我使用 image_retraining/retrain.py 使用新类别重新训练 inception 库,并使用它使用 https://github.com/llSourcell/tensorflow_image_classifier/blob/master/src/label_image.py 中的 label_image.py 对图像进行分类。如下:
import tensorflow as tf
import sys
# change this as you see fit
image_path = sys.argv[1]
# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
in tf.gfile.GFile("/root/tf_files/output_labels.txt")]
# Unpersists graph from file
with tf.gfile.FastGFile("/root/tf_files/output_graph.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
# Feed the image_data as input to the graph and get first prediction
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
#predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})
predictions = sess.run(softmax_tensor,{'DecodePng/contents:0': image_data})
# Sort to show labels of first prediction in order of confidence
top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
for node_id in top_k:
human_string = label_lines[node_id]
score = predictions[0][node_id]
print('%s (score = %.5f)' % (human_string, score))
我注意到两个问题。当我使用新类别重新训练时,它只训练 JPG 图像。我是机器学习领域的菜鸟,所以不确定这是否是一个限制,或者是否可以训练其他扩展图像,如 PNG、GIF?
另一种情况是在对图像进行分类时,输入再次仅适用于 JPG。我尝试将上面 label_image.py 中的 DecodeJpeg 更改为 DecodePng 但无法工作。我尝试的另一种方法是将其他格式转换为 JPG,然后再将它们传递进行分类,例如:
im = Image.open('/root/Desktop/200_s.gif').convert('RGB')
im.save('/root/Desktop/test.jpg', "JPEG")
image_path1 = '/root/Desktop/test.jpg'
还有其他方法可以做到这一点吗? Tensorflow是否有处理除JPG之外的其他图像格式的函数?
按照 @mrry 的建议,我通过输入解析的图像与 JPEG 进行比较,尝试了以下操作
import tensorflow as tf
import sys
import numpy as np
from PIL import Image
# change this as you see fit
image_path = sys.argv[1]
# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
image = Image.open(image_path)
image_array = np.array(image)[:,:,0:3] # Select RGB channels only.
# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
in tf.gfile.GFile("/root/tf_files/output_labels.txt")]
# Unpersists graph from file
with tf.gfile.FastGFile("/root/tf_files/output_graph.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
# Feed the image_data as input to the graph and get first prediction
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
predictions = sess.run(softmax_tensor,{'DecodeJpeg:0': image_array})
# Sort to show labels of first prediction in order of confidence
top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
for node_id in top_k:
human_string = label_lines[node_id]
score = predictions[0][node_id]
print('%s (score = %.5f)' % (human_string, score))
它适用于 JPEG 图像,但当我使用 PNG 或 GIF 时它会抛出
Traceback (most recent call last):
File "label_image.py", line 17, in <module>
image_array = np.array(image)[:,:,0:3] # Select RGB channels only.
IndexError: too many indices for array
最佳答案
该模型只能训练(和评估)JPEG 图像,因为您保存在 /root/tf_files/output_graph.pb
中的 GraphDef
仅包含tf.image.decode_jpeg()
op,并使用该 op 的输出进行预测。至少有几个使用其他图像格式的选项:
输入已解析的图像而不是 JPEG 数据。在当前程序中,您将 JPEG 编码图像作为张量
"DecodeJpeg/contents 的字符串值输入: 0”
。相反,您可以输入张量"DecodeJpeg:0"
的已解码图像数据的 3-D 数组(它表示tf.image 的输出) .decode_jpeg()
op),您可以使用 NumPy、PIL 或其他一些 Python 库来创建此数组。重新映射
tf.import_graph_def()
中的图像输入.tf.import_graph_def()
函数使您能够通过重新映射各个张量值将两个不同的图连接在一起。例如,您可以执行如下操作来向现有图形添加新的图像处理操作:image_string_input = tf.placeholder(tf.string) image_decoded = tf.image.decode_png(image_string_input) # Unpersists graph from file with tf.gfile.FastGFile("/root/tf_files/output_graph.pb", 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) softmax_tensor, = tf.import_graph_def( graph_def, input_map={"DecodeJpeg:0": image_decoded}, return_operations=["final_result:0"]) with tf.Session() as sess: # Feed the image_data as input to the graph and get first prediction predictions = sess.run(softmax_tensor, {image_string_input: image_data}) # ...
关于python - PNG、GIF 等的 Tensorflow Label_Image,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41459900/