python - 如何使用strided_slice选择tensorflow中的所有元素?

标签 python tensorflow

我阅读了文档中的示例:

# 'input' is [[[1, 1, 1], [2, 2, 2]],
#             [[3, 3, 3], [4, 4, 4]],
#             [[5, 5, 5], [6, 6, 6]]]
tf.strided_slice(input, [1, 0, 0], [2, 1, 3], [1, 1, 1]) ==> [[[3, 3, 3]]]
tf.strided_slice(input, [1, 0, 0], [2, 2, 3], [1, 1, 1]) ==> [[[3, 3, 3],
                                                               [4, 4, 4]]]
tf.strided_slice(input, [1, -1, 0], [2, -3, 3], [1, -1, 1]) ==>[[[4, 4, 4],
                                                                 [3, 3, 3]]] 

看来我不能简单地使用 input[:,:] 来选择所有元素,而是必须使用 input[:-1, :- 1]。然而,这样 input[:-1, :-1] ,我会错过最后一行或最后一列。我该怎么办?

我举个例子:

ph = tf.placeholder(shape=[None, 3], dtype=tf.int32)
x = tf.strided_slice(ph, [0,0],[-1,-1],[1,1])
input_ = np.array([[1,2,3],
                  [3,4,5],
                  [7,8,9]])
sess = tf.InteractiveSession()
sess.run(x,feed_dict={ph:input_})

输出:

array([[1, 2],
       [3, 4]])

我读了很多资料,发现可以使用tf.shape(ph),让我们看看:

ph = tf.placeholder(shape=[None, 3], dtype=tf.int32)
x = tf.strided_slice(ph, [0,0],tf.shape(ph),[1,1])
input_ = np.array([[1,2,3],
                  [3,4,5],
                  [7,8,9]])
sess = tf.InteractiveSession()
sess.run(x,feed_dict={ph:input_})

输出:

array([[1, 2, 3],
       [3, 4, 5],
       [7, 8, 9]])

但是,如果我想得到这样的结果:

[[1, 2],
 [3, 4],
 [7, 8]]

我能做什么?

最佳答案

我无法理解您的问题,但我尝试回答一下:

您可以使用 x[:, :, :] 语法来选择数组的所有元素:

sess = tf.Session()
inp = tf.constant([[[1, 1, 1], [2, 2, 2]],
                   [[3, 3, 3], [4, 4, 4]],
                   [[5, 5, 5], [6, 6, 6]]])
print(inp.shape)

x = inp[:, :, :]
print(sess.run(x))

要获得您想要的最后输出,当然可以通过一些手动尺寸计算来实现:

sess = tf.Session()
x = tf.constant([[1,2,3],
                 [3,4,5],
                 [7,8,9]])
y = tf.shape(x)
bounds = tf.concat([y[:-1], [-1]], axis=0)
out = tf.strided_slice(x, [0,0], bounds, [1,1])
print(sess.run(out))

一般来说,Tensorflow 切片语法遵循 numpy 的切片语法,记录如下: https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html

希望有帮助!

关于python - 如何使用strided_slice选择tensorflow中的所有元素?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43827792/

相关文章:

python - 使用 NLTK 将分词器组合成语法和解析器

python - 检查是否设置了 argparse 可选参数

python - Tensorflow卷积和numpy卷积的区别

tensorflow - 配置中的类比训练的类多(tensorflow 对象检测 API)

python - 在 Keras 中使用通用句子编码器嵌入层

用于委托(delegate)的 Java 内部类的 Python 等价物

python - 在 Django 应用程序中使用 Python 对象

python - 不能将序列乘以 Float 类型的非整数? Python

graph - Tensorflow:如何将 .meta、.data 和 .index 模型文件转换为一个 graph.pb 文件

tensorflow - 将编译标志传递给 bazel (TensorFlow)