python - 分配 Keras 张量的索引条目

标签 python keras

我是 Keras 的初学者,所以我提前为任何普遍的理解不足表示歉意。

我想根据存储在另一个张量中的索引手动设置我的 Keras 张量的一些值。我相信我了解如何使用 tf.gather_nd 访问张量的条目(我在下面未经测试的尝试),而且我想我知道我只能设置变量的值而不是张量。

为了清楚起见,这发生在 GAN 的生成和识别阶段之间。

gen_out = generator(inputs)

indices_to_reset = Input(shape=(1,),dtype='int32')
new_values = Input(shape=(1,), dtype='int32')

batch_size = K.shape(x)[0]

idx_0 = K.reshape(K.arange(batch_size),(1,))
indices_to_reset = K.reshape(indices_to_reset, (1,))

idx = K.stack((idx_0, indices_to_reset), axis=0)

grabbed_entries = Lambda(lambda x: tf.gather_nd(gen_out,x))(idx)

# Doesn't work
# gen_out[:,indices_to_reset] = new_values

updated_gen_out = ???

最佳答案

如果将所有内容都转换为 one_hot 张量并使用 switch 会容易得多:

(记得把所有的操作都放在lambda层里面,不然你会出问题的)

def replace_values(x):
    outs, indices, values = x

    #this is due to a strange bug between lambda and integers....
    indices = K.cast(indices, 'int32')


    #create one_hot indices
    one_hot_indices = K.one_hot(indices, size) #size is the size of gen_out
    one_hot_indices = K.batch_flatten(one_hot_indices)

    #have the desired values at their correct positions
    values_to_use = one_hot_indices * new_values


    #if values are 0, use gen_out, else use values
    return K.switch(K.equal(values_to_use, 0), outs, values_to_use)


updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])

Warning: new_values cannot be integer, they must be the same type as gen_out.


虚拟示例:

import numpy as np
from keras.layers import *
from keras.models import Model

size = 5
batch_size = 15

gen_out = Input((size,))
indices_to_reset = Input((1,), dtype='int32')
new_values = Input((1,))

def replace_values(x):
    outs, indices, values = x
    print(K.int_shape(outs))
    print(K.int_shape(indices))

    #this is due to a strange bug between lambda and integers....
    indices = K.cast(indices, 'int32')
    one_hot_indices = K.one_hot(indices, size)
    print(K.int_shape(one_hot_indices))
    one_hot_indices = K.batch_flatten(one_hot_indices)
    print(K.int_shape(one_hot_indices))

    values_to_use = one_hot_indices * new_values
    print(K.int_shape(values_to_use))

    return K.switch(K.equal(values_to_use, 0), outs, values_to_use)

updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])

model = Model([gen_out,indices_to_reset,new_values], updated_gen_out)

gen_outs = np.arange(batch_size * size).reshape((batch_size, size))
indices = np.concatenate([np.arange(5)]*3, axis=0)
new_vals = np.arange(15).reshape((15,1))

print('\n\ngen outs')
print(gen_outs)

print('\n\nindices')
print(indices)

print('\n\nvalues')
print(new_vals)

print('\n\n results')
print(model.predict([gen_outs, indices, new_vals]))

输出:

(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)


gen outs
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]
 [25 26 27 28 29]
 [30 31 32 33 34]
 [35 36 37 38 39]
 [40 41 42 43 44]
 [45 46 47 48 49]
 [50 51 52 53 54]
 [55 56 57 58 59]
 [60 61 62 63 64]
 [65 66 67 68 69]
 [70 71 72 73 74]]


indices
[0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]


values
[[ 0]
 [ 1]
 [ 2]
 [ 3]
 [ 4]
 [ 5]
 [ 6]
 [ 7]
 [ 8]
 [ 9]
 [10]
 [11]
 [12]
 [13]
 [14]]


 results
[[ 0.  1.  2.  3.  4.]
 [ 5.  1.  7.  8.  9.]
 [10. 11.  2. 13. 14.]
 [15. 16. 17.  3. 19.]
 [20. 21. 22. 23.  4.]
 [ 5. 26. 27. 28. 29.]
 [30.  6. 32. 33. 34.]
 [35. 36.  7. 38. 39.]
 [40. 41. 42.  8. 44.]
 [45. 46. 47. 48.  9.]
 [10. 51. 52. 53. 54.]
 [55. 11. 57. 58. 59.]
 [60. 61. 12. 63. 64.]
 [65. 66. 67. 13. 69.]
 [70. 71. 72. 73. 14.]] 

请注意 gen_outs 的对角线值已替换为 new_vals 中的值。

关于python - 分配 Keras 张量的索引条目,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54717678/

相关文章:

python - 向量化投资组合风险

python - 如何将数组复制/重复 N 次到新数组中?

python - 将 lambda 函数作为变量传递给 timeit 中的另一个函数

predict - 当keras进行主题预测时,mse的损失始终为0

python-3.x - 这是将 Keras 与 tensorflow 数据集一起使用时的错误吗?

python - 如何在 if 语句中正确使用 Or with strings

python - 在多个数据集上训练神经网络模型

python-3.x - keras连接多个层导致AttributeError : 'NoneType' object has no attribute '_inbound_nodes'

machine-learning - 在输入层中,input_dim和输入层上的节点数有什么区别?

python - 为什么 Keras ImageDataGenerator 会抛出内存错误?