python - tf.scatter_update 或 tf.scatter_nd_update 可以用于更新张量的列切片吗?

标签 python tensorflow

我想实现一个函数,它接受一个变量作为输入,改变它的一些行或列,并将它们替换回原始变量。我可以使用 tf.gather 和 tf.scatter_update 对行切片实现它,但无法对列切片执行此操作,因为显然 tf.scatter_update 仅更新行切片并且没有轴功能。我不是 tensorflow 专家,因此我可能会遗漏一些东西。有人可以帮忙吗?

def matrix_reg(t, percent_t, beta):
    
    ''' Takes a variable tensor t as input and regularizes some of its rows.
    The number of rows to be regularized are specified by the percent_t. Returns the original tensor by updating its rows indexed by row_ind.
    
    Arguments:
        t -- input tensor
        percent_t -- percentage of the total rows
        beta -- the regularization factor
    Output:
        the regularized tensor
        '''
    row_ind = np.random.choice(int(t.shape[0]), int(percent_t*int(t.shape[0])), replace = False)
    t_ = tf.gather(t,row_ind)
    t_reg = (1+beta)*t_-beta*(tf.matmul(tf.matmul(t_,tf.transpose(t_)),t_))
    return tf.scatter_update(t, row_ind, t_reg)

最佳答案

以下是如何更新行或列的小演示。这个想法是,您指定希望更新中每个元素结束的变量的行索引和列索引。使用 tf.meshgrid 很容易做到这一点.

import tensorflow as tf

var = tf.get_variable('var', [4, 3], tf.float32, initializer=tf.zeros_initializer())
updates = tf.placeholder(tf.float32, [None, None])
indices = tf.placeholder(tf.int32, [None])
# Update rows
var_update_rows = tf.scatter_update(var, indices, updates)
# Update columns
col_indices_nd = tf.stack(tf.meshgrid(tf.range(tf.shape(var)[0]), indices, indexing='ij'), axis=-1)
var_update_cols = tf.scatter_nd_update(var, col_indices_nd, updates)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('Rows updated:')
    print(sess.run(var_update_rows, feed_dict={updates: [[1, 2, 3], [4, 5, 6]], indices: [3, 1]}))
    print('Columns updated:')
    print(sess.run(var_update_cols, feed_dict={updates: [[1, 5], [2, 6], [3, 7], [4, 8]], indices: [0, 2]}))

输出:

Rows updated:
[[0. 0. 0.]
 [4. 5. 6.]
 [0. 0. 0.]
 [1. 2. 3.]]
Columns updated:
[[1. 0. 5.]
 [2. 5. 6.]
 [3. 0. 7.]
 [4. 2. 8.]]

关于python - tf.scatter_update 或 tf.scatter_nd_update 可以用于更新张量的列切片吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52872239/

相关文章:

c++ - 在 Python 中训练神经网络并在 C++ 中部署

python - tensorflow 中的自定义卷积函数

python - pycharm无法运行脚本但可以调试它

python - 使用 python 运行 bash 命令并获取输出

python - 如何使用 QTableWidget PageUp/PageDown 表格?

python - 使用 Beautifulsoup 的类的正则表达式

python-3.x - Tensorflow 合并数据集的替代方法

tensorflow - 将 tensorflow 数据集记录分成多条记录

python - 保留一列的首选值并删除不太首选的值

python - 动态更新 pydantic 中子类所有属性的类型提示