python - 稀疏张量的 while_loop 中的 InvalidArgumentError

标签 python tensorflow

我正在使用 while_loop 迭代更新矩阵。对于密集张量,循环运行良好,但是当我使用稀疏张量时,出现以下错误:

InvalidArgumentError: Number of rows of a_indices does not match number of entries in a_values [[Node: while/SparseTensorDenseMatMul/SparseTensorDenseMatMul = SparseTensorDenseMatMul[T=DT_FLOAT, Tindices=DT_INT64, adjoint_a=false, adjoint_b=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](while/SparseTensorDenseMatMul/SparseTensorDenseMatMul/Enter, while/SparseTensorDenseMatMul/SparseTensorDenseMatMul/Enter_1, ConstantFolding/dense_to_sparse/Shape_enter/_1, while/Switch_1:1)]]
[[Node: while/Exit_1/_5 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_62_while/Exit_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

我在两个版本之间唯一改变的是用 HH=tf.contrib.layers.dense_to_sparse(HH) 转换 HH 并使用 tf.sparse_tensor_dense_matmul(HH,f) 而不是 tf.matmul(HH,f) - 如下面的注释代码所示。

with tf.device('/gpu:0'):
    g=tf.constant(g,shape=[np.size(g),1],dtype=tf.float32)
    H=tf.constant(H,dtype=tf.float32);
    Ht=tf.transpose(H)
    HH=tf.matmul(Ht,H)
    #HH=tf.contrib.layers.dense_to_sparse(HH)
    a=tf.matmul(Ht,g)
    i=tf.constant(0,dtype=tf.int32)
    f=tf.constant(f,dtype=tf.float32)
    body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.matmul(HH,f)+10e-9))
    #body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.sparse_tensor_dense_matmul(HH,f)+10e-9))
    cond= lambda i,f:tf.less(i,iterations)
    i,f=tf.while_loop(cond,body,(i,f))
sess=tf.Session()
i,f=sess.run([i,f])

请注意,只要 H、g 和 f 足够小,此代码就可以工作。例如,此错误发生在 H.shape=(8000,3840) 、g.shape=(8000,1)、f.shape=(3840,1) 和更大的情况下,但对于 H.shape=(8000,第3584章 ,g.shape=(8000,1),f.shape=(3584,1)或更小我是否需要在 while 循环中对稀疏张量做一些特殊的事情以确保它们保持其形状?

最佳答案

我尝试从tensorflow 1.8更新到1.12,但tensorflow完全停止工作(ts.Session将无限期挂起)。因此,我破坏了 anaconda 环境,并从头开始使用 TensorFlow 1.12。更新/重新安装后,稀疏张量的问题消失了,但尚不清楚问题是否出在我的 anaconda 环境中的 tensorflow 版本或其他问题上。

关于python - 稀疏张量的 while_loop 中的 InvalidArgumentError,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54444626/

相关文章:

python - 是否可以在 python 脚本中交叉引用 bash 和 python 变量

python - cronjob-script 创建的日志的位置

python - ld : library not found for -lboost_python on MacOS

python - Python 中的 SAML 2.0 服务提供商

python - Tensorflow:tf.assign 不分配任何东西

python - 值错误: Output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`

python - DCGAN : discriminator getting too strong too quickly to allow generator to learn

python - 循环excel文件并基于Python中的一个公共(public)列进行合并

python - 使用 MirroredStrategy : isinstance(x, dataset_ops.DatasetV2 时出现 AssertionError)

machine-learning - TensorFlow:多次评估测试集但得到不同的精度