python - 如何正确使用 tf.nn.max_pool_with_argmax

标签 python tensorflow

目前,我稍微玩了一下 tensorflow,以便更好地理解机器学习和 tensorflow 本身。因此,我想将 tensorflow 的方法(尽可能多地)可视化。为了可视化 max_pool,我加载了一张图像并执行了该方法。之后我同时显示:输入和输出图像。

import tensorflow as tf
import cv2
import numpy as np

import matplotlib.pyplot as plt

image = cv2.imread('lena.png')
image_tensor = tf.expand_dims(tf.Variable(image, dtype=tf.float32), 0)

#output, argmax = tf.nn.max_pool_with_argmax(image_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')
output = tf.nn.max_pool(image_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')

init = tf.initialize_all_variables()
session = tf.Session()
session.run(init)

output = session.run(output)

session.close()

image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
plt.figure()
plt.imshow(image)
plt.show()

output = cv2.cvtColor(output[0], cv2.COLOR_RGB2BGR)
plt.figure()
plt.imshow(255-output)
plt.show() 

一切正常,我得到了这个输出(如预期的那样)

image (input) enter image description here

现在我想测试方法 tf.nn.max_pool_with_argmax 来获取池化操作的 argmax。但是如果我取消注释该行

输出,argmax = tf.nn.max_pool_with_argmax(image_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name= '池1')

Python 崩溃

tensorflow.python.framework.errors.InvalidArgumentError: No OpKernel was registered to support Op 'MaxPoolWithArgmax' with these attrs [[Node: pool1 = MaxPoolWithArgmaxT=DT_FLOAT, Targmax=DT_INT64, ksize=[1, 2, 2, 1], padding="SAME", strides=[1, 2, 2, 1]]]

我不知道哪个参数是错误的,因为每个参数都应该是正确的 ( tensorflow docs ) ...

谁知道哪里出了问题?

最佳答案

从看the implementation , 看来 tf.nn.max_pool_with_argmax()仅针对 GPU 实现。如果您正在运行 TensorFlow 的纯 CPU 构建,那么您将收到格式为 “No OpKernel was registered to support Op 'MaxPoolWithArgmax' with these attrs ...” 的错误。

(这似乎是可以改进文档和错误消息的地方。)

关于python - 如何正确使用 tf.nn.max_pool_with_argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39493229/

相关文章:

tensorflow - 为什么我要选择与我的指标不同的损失函数?

tensorflow - 如何在 Tensorflow.js 中改变张量的值?

python-2.7 - Tensorflow saver.restore() 不恢复网络

java - spring boot API - 文档处理并在文档上并行执行 python 脚本

python - 迭代 df 行并附加到不带名称和 dtype 的列表

python - 如何在 Linux 中用 Python 捕获 gtk ApplicationWindow 的图像?

tensorflow - 如何在 tensorflow 中读取 utf-8 编码的二进制字符串?

python - Tensorflow 仅显示 "successfully opened CUDA library libcublas.so.10.0 locally",与 cudnn 无关

javascript - 在 Firefox 中通过 Selenium 访问 ShadowRoot 返回 JavascriptException : Cyclic object value

python - 为什么此环视正则表达式返回意外结果?