我是 Keras 的初学者,所以我提前为任何普遍的理解不足表示歉意。
我想根据存储在另一个张量中的索引手动设置我的 Keras 张量的一些值。我相信我了解如何使用 tf.gather_nd
访问张量的条目(我在下面未经测试的尝试),而且我想我知道我只能设置变量的值而不是张量。
为了清楚起见,这发生在 GAN 的生成和识别阶段之间。
gen_out = generator(inputs)
indices_to_reset = Input(shape=(1,),dtype='int32')
new_values = Input(shape=(1,), dtype='int32')
batch_size = K.shape(x)[0]
idx_0 = K.reshape(K.arange(batch_size),(1,))
indices_to_reset = K.reshape(indices_to_reset, (1,))
idx = K.stack((idx_0, indices_to_reset), axis=0)
grabbed_entries = Lambda(lambda x: tf.gather_nd(gen_out,x))(idx)
# Doesn't work
# gen_out[:,indices_to_reset] = new_values
updated_gen_out = ???
最佳答案
如果将所有内容都转换为 one_hot 张量并使用 switch 会容易得多:
(记得把所有的操作都放在lambda层里面,不然你会出问题的)
def replace_values(x):
outs, indices, values = x
#this is due to a strange bug between lambda and integers....
indices = K.cast(indices, 'int32')
#create one_hot indices
one_hot_indices = K.one_hot(indices, size) #size is the size of gen_out
one_hot_indices = K.batch_flatten(one_hot_indices)
#have the desired values at their correct positions
values_to_use = one_hot_indices * new_values
#if values are 0, use gen_out, else use values
return K.switch(K.equal(values_to_use, 0), outs, values_to_use)
updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])
Warning:
new_values
cannot be integer, they must be the same type asgen_out
.
虚拟示例:
import numpy as np
from keras.layers import *
from keras.models import Model
size = 5
batch_size = 15
gen_out = Input((size,))
indices_to_reset = Input((1,), dtype='int32')
new_values = Input((1,))
def replace_values(x):
outs, indices, values = x
print(K.int_shape(outs))
print(K.int_shape(indices))
#this is due to a strange bug between lambda and integers....
indices = K.cast(indices, 'int32')
one_hot_indices = K.one_hot(indices, size)
print(K.int_shape(one_hot_indices))
one_hot_indices = K.batch_flatten(one_hot_indices)
print(K.int_shape(one_hot_indices))
values_to_use = one_hot_indices * new_values
print(K.int_shape(values_to_use))
return K.switch(K.equal(values_to_use, 0), outs, values_to_use)
updated_gen_out = Lambda(replace_values)([gen_out, indices_to_reset, new_values])
model = Model([gen_out,indices_to_reset,new_values], updated_gen_out)
gen_outs = np.arange(batch_size * size).reshape((batch_size, size))
indices = np.concatenate([np.arange(5)]*3, axis=0)
new_vals = np.arange(15).reshape((15,1))
print('\n\ngen outs')
print(gen_outs)
print('\n\nindices')
print(indices)
print('\n\nvalues')
print(new_vals)
print('\n\n results')
print(model.predict([gen_outs, indices, new_vals]))
输出:
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
(None, 5)
(None, 1)
(None, 1, 5)
(None, None)
(None, None)
gen outs
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]
[25 26 27 28 29]
[30 31 32 33 34]
[35 36 37 38 39]
[40 41 42 43 44]
[45 46 47 48 49]
[50 51 52 53 54]
[55 56 57 58 59]
[60 61 62 63 64]
[65 66 67 68 69]
[70 71 72 73 74]]
indices
[0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]
values
[[ 0]
[ 1]
[ 2]
[ 3]
[ 4]
[ 5]
[ 6]
[ 7]
[ 8]
[ 9]
[10]
[11]
[12]
[13]
[14]]
results
[[ 0. 1. 2. 3. 4.]
[ 5. 1. 7. 8. 9.]
[10. 11. 2. 13. 14.]
[15. 16. 17. 3. 19.]
[20. 21. 22. 23. 4.]
[ 5. 26. 27. 28. 29.]
[30. 6. 32. 33. 34.]
[35. 36. 7. 38. 39.]
[40. 41. 42. 8. 44.]
[45. 46. 47. 48. 9.]
[10. 51. 52. 53. 54.]
[55. 11. 57. 58. 59.]
[60. 61. 12. 63. 64.]
[65. 66. 67. 13. 69.]
[70. 71. 72. 73. 14.]]
请注意 gen_outs
的对角线值已替换为 new_vals
中的值。
关于python - 分配 Keras 张量的索引条目,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54717678/