python - 使用随机生成的整数作为tensorflow map_fn中的张量索引

标签 python numpy tensorflow keras

我在 Keras 的自定义层中初始化张量 A ,其中 batchSize 是占位符:

A = K.zeros([batchSize, 2, 2 ,2])

我还初始化了一个大小为 [3,2,2,2] 的 numpy 数组 B。我想从 B 中随机选择 [i,2,2,2] 大小的数组,其中 i = 0,1,2,并且将其分配给 A 的第一个维度并重复此batchSize 次数。

由于我无法显式循环batchSize,因此我尝试了tensorflow.map_fn,如下所示:

ANew = tf.map_fn(lambda x: K.variable(B[np.random.randint(0,3,size=(1)).tolist()[0],:,:,:],
                 A, dtype=’float’, back_prop=False, infer_shape=True)

这会生成ANew张量。然而,看起来 np.random.randint 只被调用一次;结果,我总是选择相同的索引。如何修改代码以便 np.random.randint(0,3,size=(1)).tolist()[0] 被调用batchSize 次数?

最佳答案

您正在寻找K.gather

A = K.gather(B, indices_list)

关于python - 使用随机生成的整数作为tensorflow map_fn中的张量索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49926256/

相关文章:

python - python 样式结构的 BNF 语法

python - 创建 Django 管理中级页面

PYTHONPATH 干扰 virtualenv

python - 如何在python中找到大于均值的列表的最长连续子序列

tensorflow - 了解 Keras 模型架构(张量索引)

使用 virtualenv 进行 Python 部署(在无法访问互联网的服务器上)

python - 提高h5py的阅读速度

python - 堆叠 numpy 数组?

python - 从命令行强制 TensorFlow-GPU 使用 CPU

python - 训练Tensorflow时出现ValueError : setting an array element with a sequence