python - Tensorflow 中的单图像推理 [Python]

标签 python tensorflow

我已经将预训练的 .ckpt 文件转换为 .pb 文件,卡住模型并保存权重。我现在要做的是使用该 .pb 文件进行简单推理,然后提取并保存输出图像。该模型是从此处下载的(用于语义分割的全卷积网络):https://github.com/MarvinTeichmann/KittiSeg .到目前为止,我已经设法加载图像,设置默认 tf 图并导入模型定义的图,读取输入和输出张量并运行 session (此处出错)。

import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image

# Read the image & get statstics
img=Image.open('/path-to-image/demoImage.png')
img.show()
width, height = img.size
print(width)
print(height)

#Plot the image
#image.show()

with tf.Graph().as_default() as graph:

        with tf.Session() as sess:

                # Load the graph in graph_def
                print("load graph")

                # We load the protobuf file from the disk and parse it to retrive the unserialized graph_drf
                with gfile.FastGFile("/path-to-FCN-model/FCN8.pb",'rb') as f:

                                #Set default graph as current graph
                                graph_def = tf.GraphDef()
                                graph_def.ParseFromString(f.read())
                                #sess.graph.as_default() #new line

                                # Import a graph_def into the current default Graph
                                tf.import_graph_def(graph_def, name='')

                                # Print the name of operations in the session
                                #for op in sess.graph.get_operations():

                                    #print "Operation Name :",op.name            # Operation name
                                    #print "Tensor Stats :",str(op.values())     # Tensor name

                                # INFERENCE Here
                                l_input = graph.get_tensor_by_name('Placeholder:0')
                                l_output = graph.get_tensor_by_name('save/Assign_38:0')

                                print "l_input", l_input
                                print "l_output", l_output
                                print
                                print

                                # Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.                              
                                result = sess.run(l_output, feed_dict={l_input : img})
                                print(results)

                                print("Inference done")

                                # Info
                                # First Tensor name : Placeholder:0
                                # Last tensor name  : save/Assign_38:0"

错误是否来自图像格式(例如,我应该将 .png 转换为另一种格式吗?)。这是另一个根本性错误吗?

最佳答案

我设法修复了这个错误,下面是在完全卷积网络上推断单个图像的工作脚本(对于任何对 SEGNET 的替代分割算法感兴趣的人)。该模型使用双线性插值进行缩放而不是非池化层。无论如何,因为模型可以以 .chkpt 格式下载,所以您必须先卡住模型并将其另存为 .pb 文件。稍后,您必须从 TF 优化器传递网络以将 Dropout 概率设置为 1。之后,在此脚本中设置正确的输入和输出张量名称,推理工作正常,提取分割图像。

import tensorflow as tf # Default graph is initialized when the library is imported
import os
from tensorflow.python.platform import gfile
from PIL import Image
import numpy as np
import scipy
from scipy import misc
import matplotlib.pyplot as plt
import cv2

with tf.Graph().as_default() as graph: # Set default graph as graph

           with tf.Session() as sess:
                # Load the graph in graph_def
                print("load graph")

                # We load the protobuf file from the disk and parse it to retrive the unserialized graph_drf
                with gfile.FastGFile("/path-to-protobuf/FCN8_Freezed.pb",'rb') as f:

                                print("Load Image...")
                                # Read the image & get statstics
                                image = scipy.misc.imread('/Path-To-Image/uu_000010.png')
                                image = image.astype(float)
                                Input_image_shape=image.shape
                                height,width,channels = Input_image_shape

                                print("Plot image...")
                                #scipy.misc.imshow(image)

                                # Set FCN graph to the default graph
                                graph_def = tf.GraphDef()
                                graph_def.ParseFromString(f.read())
                                sess.graph.as_default()

                                # Import a graph_def into the current default Graph (In this case, the weights are (typically) embedded in the graph)

                                tf.import_graph_def(
                                graph_def,
                                input_map=None,
                                return_elements=None,
                                name="",
                                op_dict=None,
                                producer_op_list=None
                                )

                                # Print the name of operations in the session
                                for op in graph.get_operations():
                                        print "Operation Name :",op.name         # Operation name
                                        print "Tensor Stats :",str(op.values())     # Tensor name

                                # INFERENCE Here
                                l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0') # Input Tensor
                                l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0') # Output Tensor

                                print "Shape of input : ", tf.shape(l_input)
                                #initialize_all_variables
                                tf.global_variables_initializer()

                                # Run Kitty model on single image
                                Session_out = sess.run( l_output, feed_dict = {l_input : image} 

关于python - Tensorflow 中的单图像推理 [Python],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45697823/

相关文章:

python - 如何在 tensorflow 中交替训练操作?

python - 如何在 Odoo 中创建新附件?它们是如何工作的?

python - 不明白为什么这段代码会引发 EOF 错误

python - undefined symbol : clock_gettime with tensorflow on ubuntu14. 04

python - tf.nn.conv2d vs tf.layers.conv2d

python - 更改 TensorFlow Session 的默认配置?

python - Pandas 会影响 Rapidfuzz 匹配的结果吗?

python - 我可以使用 Node.js 作为后端并使用 Python 来进行 AI 计算吗?

python - Easy Python Function——为什么我会得到这么多错误?

python - 在 Tensorflow 中将一组常量(一维数组)与一组矩阵(三维数组)相乘