python - Tensorflow:tf.case 参数化可调用,案例列表在 for 循环中定义

标签 python lambda machine-learning tensorflow

我正在尝试在一组自动编码器的训练循环中实现一个案例分支:根据特定条件,只应更新一个特定的自动编码器。我一直在尝试使用 tf.case() 来实现这一点,但它没有像我预期的那样工作......

def f(k_win):

    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])

    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)] 

winner_index = tf.argmin(Cost_Alpha_List, 0)



Case_List = []

for k in range(N_Class): 

    Case = (tf.equal(winner_index,k), lambda: f(k))   

    Case_List.append(Case)


Execution_List = tf.case(Case_List, lambda: f(0))

winner_index:要更新的自动编码器索引

f(k_win):返回特定 AE 索引的所有更新可调用对象

Case_List:包含成对的 bool 值和参数化函数

Execution_List:可在执行循环中调用 sess.run()。

for 循环中的参数 k 应该定义 Case_List,特别是 'lambda: f(k)',但似乎在构建列表后,所有 'lambda: f(k)' 都设置为last k=N_Classes-1:效果是,只有最后一个自动编码器会被更新,而不是带有“winner_index”的那个。有谁知道这里发生了什么......?

谢谢。

最佳答案

问题是您正在定义的 lambda 正在使用全局变量 k,该变量在函数被调用时具有它在循环中采用的最后一个值( N_Class - 1).

一个更简单的例子:

lst = []
for k in range(10):
    lst.append(lambda: k * k)
print([lst_i() for lst_i in lst])

给予:

[81, 81, 81, 81, 81, 81, 81, 81, 81, 81]

代替:

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

This answer更好地解释了这个问题,并指出了一些克服这个问题的方法。在你的情况下,你可以这样做:

def f(k_win):

    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])

    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)] 

winner_index = tf.argmin(Cost_Alpha_List, 0)



Case_List = []

for k in range(N_Class): 

    Case = (tf.equal(winner_index,k), (lambda kk: lambda: f(kk))(k))   

    Case_List.append(Case)


Execution_List = tf.case(Case_List, lambda: f(0))

关于python - Tensorflow:tf.case 参数化可调用,案例列表在 for 循环中定义,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43258027/

相关文章:

machine-learning - 在贝叶斯网络中,节点 "instantiated"是什么意思

python - 如何处理分类器中不平衡的类?

Python:如何使用 1 个 lambda 函数进行多种浮点格式设置

java - 在 Java 8+ 中,流中只允许使用单参数方法引用

java - 如何让 printf %x 类似于 python

python - sqlite3 "OperationalError: near "(": syntax error" python

c# - Lambda 表达式和 InvokeOperation

python - 如何批量训练具有多个输入的模型?

python - 在 Keras 中保存模型时引发 'Unable to create group (name already exists)' 错误

python - 寻找一种同时循环遍历两个不同长度列表的优雅方法