python - 哪个更有效 : tf. where 或 element-wise multiplication?

标签 python tensorflow deep-learning keras loss

我正在实现一个损失函数,该函数将使用掩码张量 (M)0s and 1s 组成消除一些给定预测的损失值 (P)和地面真相(G)张量。

所以,我有两种可能的方法:

按元素乘法:
loss = K.sum(M * K.binary_crossentropy(G, P))
条件选择:

bin_ce = K.binary_crossentropy(G, P)
loss = K.sum(tf.where(tf.equal(M, 1), bin_ce, 0))

那么,在运行时间方面哪个更有效?

最佳答案

我做了基准测试,很明显乘法比条件选择好得多。

结果如下:

A chart is worth a thousand words..

一张图表值一千字。

基准代码:

import keras.backend as K
import tensorflow as tf
import numpy as np
import sys
import time
import matplotlib.pyplot as plt


def elm(G, P, M):
        return K.sum(M * K.binary_crossentropy(G, P))

def cond(G, P, M, t):
        C = K.variable(np.zeros((t, t)))
        bin_ce = K.binary_crossentropy(G, P)
        return K.sum(tf.where(tf.equal(M, 1), bin_ce, C))


s = [100, 1000, 10000, 100000]
elms = []
conds = []

for t in s:
        print t
        t = int(t)
        # number of 1s in mask
        n = int(t/2)

        M = np.zeros((t,t))
        P = np.random.rand(t, t)
        G = np.random.rand(t, t)

        for i in range(n):
                r = np.random.randint(0, t)
                c = np.random.randint(0, t)
                M[r,c] = 1

        M = K.variable(M)
        P = K.variable(P)
        G = K.variable(G)

        start_time = time.time()
        elm(G, P, M)
        elms.append(time.time() - start_time)

        start_time = time.time()
        cond(G, P, M, t)
        conds.append(time.time() - start_time)

print elms
print conds

# create plot
fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.35
opacity = 0.8

rects1 = plt.bar(index, elms, bar_width,
                 alpha=opacity,
                 color='b',
                 label='Element-wise')

rects2 = plt.bar(index + bar_width, conds, bar_width,
                 alpha=opacity,
                 color='g',
                 label='Conditional')

plt.xlabel('Input tensor size')
plt.ylabel('Execution time (s)')
plt.title('')
plt.xticks(index + bar_width, ('100', '10e3', '10e4', '10e5'))
plt.legend()

plt.tight_layout()
plt.show()

关于python - 哪个更有效 : tf. where 或 element-wise multiplication?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46644796/

相关文章:

python - 在 MxNet-Gluon 中将 ROIPooling 层与预训练的 ResNet34 模型结合使用

python - 使用 grep 或 sed 合并两行

python - 无法填充 QTableWidget

python - 具有未知batch_size的Keras重复元素

tensorflow - 获取错误 "Resource exhausted: OOM when allocating tensor with shape[1800,1024,28,28] and type float on/job:localhost/..."

python - TypeError : tuple indices must be integers or slices, 未列出 - 加载模型 Keras 时

python - 在较大列表中处理可变大小的子列表

python - 3秒后如何创建线程拍照?

machine-learning - 卷积深度置信网络 (CDBN) 与卷积神经网络 (CNN)

tensorflow - 如何在 tensorflow tfrecords 中增加数据?