python - 在 TensorFlow 中使用循环填充矩阵

标签 python numpy matrix machine-learning tensorflow

因此,我尝试在 TensorFlow 中填充一个矩阵,该矩阵的大小会根据输入而变化,因此我使用 TensorArray 来执行此操作。从本质上讲,Numpy 的等价物是:

areas = np.zeros((len(rows)-1,len(cols)-1))
for r in range(len(rows)-1):
    for c in range(len(cols)-1):
        areas[r,c] = (rows[r+1]-rows[r])*(cols[c+1]-cols[c])

我尝试使用 tf.while_looptf.TensorArray 在 TensorFlow 中实现它:

i = tf.constant(0)
areas = tf.TensorArray(dtype='float32', size=length_rc-1)
while_condition = lambda i, rows, areas: tf.less(i, length_rc-1)
def row_loop(i, rows, areas):        
    j = tf.constant(0)
    area = tf.TensorArray(dtype='float32', size=length_rc-1)
    while_condition = lambda j, cols, area: tf.less(j, length_rc-1)

    def col_loop(j, cols, area):
        area = area.write(j, tf.multiply(tf.subtract(rows[i+1],rows[i]),tf.subtract(cols[j+1],cols[j])))
        return [tf.add(j,1), cols, area]

    r = tf.while_loop(while_condition, col_loop, [j, cols, areas])
    areas = areas.write(i, r[2].stack())
    return [tf.add(i, 1), rows, areas]

# do the loop:
r = tf.while_loop(while_condition, row_loop, [i, rows, areas])
areas = r[2].stack()

p = sess.run([areas], feed_dict={pred_batch: pred, gt_batch: gt})

但是,它似乎不起作用,我也不确定为什么。如您所见,我的代码类似于这篇文章: Howe TensorArray and while_loop work together in tensorflow?

但是好像不行,谁知道是什么问题?我得到的具体错误是:

ValueError: Inconsistent shapes: saw (?,) but expected () (and infer_shape=True)

最佳答案

什么不起作用?您期望发生什么与实际发生什么?

一方面,在这两种情况下,您的循环条件看起来都差了 1。在第一种情况下,您将错过最后一行和最后一列,因为 range 只产生小于其参数的值。

同样,在第二种情况下,你的条件是tf.less(i, length_rc-1):你可能希望i等于length_rc -1 在最后一次迭代中,不小于它。条件应该是 tf.less(i, length_rc)

关于python - 在 TensorFlow 中使用循环填充矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44790538/

相关文章:

python - 找出所有能整除一个数的数

python - Plotly:当点超过 20 时,散点图标记消失

numpy - 如何在没有循环的情况下为多个整数类型在 numpy 中创建 bool 索引

python - 使用一维列索引数组切片并填充二维数组

c# - object[,] 在 C# 中是什么意思?

matlab - 获取矩阵子集的边界单元格的索引。软件

python - 如何绘制非数值数据的日期时间和 value_counts() ?

python - 如何合并 pyspark 和 pandas 数据框

python - 在 python/numpy 中随机放置给定数字的矩阵

python : Cannot save plots as png