我正在尝试在一组自动编码器的训练循环中实现一个案例分支:根据特定条件,只应更新一个特定的自动编码器。我一直在尝试使用 tf.case() 来实现这一点,但它没有像我预期的那样工作......
def f(k_win):
update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]
winner_index = tf.argmin(Cost_Alpha_List, 0)
Case_List = []
for k in range(N_Class):
Case = (tf.equal(winner_index,k), lambda: f(k))
Case_List.append(Case)
Execution_List = tf.case(Case_List, lambda: f(0))
winner_index:要更新的自动编码器索引
f(k_win):返回特定 AE 索引的所有更新可调用对象
Case_List:包含成对的 bool 值和参数化函数
Execution_List:可在执行循环中调用 sess.run()。
for 循环中的参数 k 应该定义 Case_List,特别是 'lambda: f(k)',但似乎在构建列表后,所有 'lambda: f(k)' 都设置为last k=N_Classes-1:效果是,只有最后一个自动编码器会被更新,而不是带有“winner_index”的那个。有谁知道这里发生了什么......?
谢谢。
最佳答案
问题是您正在定义的 lambda 正在使用全局变量 k
,该变量在函数被调用时具有它在循环中采用的最后一个值( N_Class - 1
).
一个更简单的例子:
lst = []
for k in range(10):
lst.append(lambda: k * k)
print([lst_i() for lst_i in lst])
给予:
[81, 81, 81, 81, 81, 81, 81, 81, 81, 81]
代替:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
This answer更好地解释了这个问题,并指出了一些克服这个问题的方法。在你的情况下,你可以这样做:
def f(k_win):
update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]
winner_index = tf.argmin(Cost_Alpha_List, 0)
Case_List = []
for k in range(N_Class):
Case = (tf.equal(winner_index,k), (lambda kk: lambda: f(kk))(k))
Case_List.append(Case)
Execution_List = tf.case(Case_List, lambda: f(0))
关于python - Tensorflow:tf.case 参数化可调用,案例列表在 for 循环中定义,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43258027/