Tensorflow:在未经验证的形状变量上应用 bool 掩码?

标签 tensorflow

好的,下面的场景:我有一个变量 var,它的等级是固定的,但它的形状不是。例如,它可以是任意长度的一维张量。我想在与我的图形的 session 开始时初始化一次 var。我使用附加到此变量的占位符来执行此操作(另请参见下面的代码)。然后我在我的图中做了一些计算,在某些时候我需要从 var 中提取所有大于 0 的值,就像这样:

import tensorflow as tf    

init_var = tf.placeholder(dtype=tf.float64, shape=[None])
var = tf.Variable(init_var,dtype=tf.float64,validate_shape=False)
booled = tf.boolean_mask(var, var>0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), { init_var: [1,-2,3] } )
    print sess.run([booled])

但这会产生 ValueError-Exception:

ValueError: Number of mask dimensions must be specified, 
even if some dimensions are None. 
E.g. shape=[None] is ok, but shape=None is not.

现在,如果我将 validate_shape 设置为 True,此异常就会消失,但我需要在构建图形时修复 var 的形状,但我希望它是动态的。尽管如此,如果有人知道如何评估未验证形状变量的 bool 掩码或如何在每个 session 中重新初始化 var 的形状(可能无需重建整个图形),我将非常感激。

最佳答案

好的,我同时解决了这个问题,事实证明这个解决方案难以置信简单。虽然在定义变量时似乎不可能指定带有“无”条目的形状(因此只指定其等级),但可以在 var.set_shape() 之后立即执行此操作像这样:

import tensorflow as tf    

init_var = tf.placeholder(dtype=tf.float64, shape=[None])
var = tf.Variable(init_var,dtype=tf.float64,validate_shape=False)
var.set_shape([None])
booled = tf.boolean_mask(var, var>0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), { init_var: [1,-2,3] } )
    print sess.run([booled])

现在它完全符合我的期望!

关于Tensorflow:在未经验证的形状变量上应用 bool 掩码?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47064284/

相关文章:

python - 加载的 .pb 文件的 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 为空

python - 尚未在model.summary()上建立此模型错误

python - 在 K 均值聚类中组织聚类

python - 无法在 anaconda tenserenv 中运行 convolutional.py

python - Mac OS High Sierra : Tensorflow verions returned by `pip3 upgrade ` and `python3 -c ' import tensorflow as tf; print(tf. __version__ )'` 不同

tensorflow - 如何在 TF-Slim 的 eval_image_classifier.py 中获取错误分类的文件?

python - 如何将 tensorflow 模型部署到azure ml工作台

python - 使用中间层作为输入和输出的 keras 模型

python - 属性错误 : 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'decode'

python - 带有空张量的tensorflow scatter_nd?