python - 将多个函数应用于数组的每一行

标签 python arrays performance numpy

我有一个 numpy 数组,它只有几个非零条目,可以是正数也可以是负数。例如。像这样:

myArray = np.array([[ 0.        ,  0.        ,  0.        ],
       [ 0.32, -6.79,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  1.5        ,  0.        ],
       [ 0.        ,  0.        , -1.71]])

最后,我想收到一个列表,其中该列表的每个条目对应于 myArray 的一行,并且是函数输出的累积乘积,该输出取决于 myArray 和另一个列表的相应行的条目(在下面的示例称为 l)。 各个项取决于 myArray 条目的符号:当它为正时,我应用“funPos”,当它为负时,我应用“funNeg”,如果条目为 0,则该项将为 1。所以在示例中上面的数组将是:

output = [1*1*1 , 
         funPos(0.32, l[0])*funNeg(-6.79,l[1])*1, 
         1*1*1, 
         1*funPos(1.5, l[1])*1, 
         1*1*funNeg(-1.71, l[2])]

我如下所示实现了它,它给了我想要的输出(注意:这只是一个高度简化的玩具示例;实际的矩阵要大得多,函数也要复杂得多)。我遍历数组的每一行,如果行的总和为 0,我不需要做任何计算,输出只是 1。如果它不等于 0,我遍历这一行,检查符号每个值并应用适当的函数。

import numpy as np
def doCalcOnArray(Array1, myList):

    output = np.ones(Array1.shape[0]) #initialize output

    for indRow,row in enumerate(Array1):

    if sum(row) != 0: #only then calculations are needed
        tempProd = 1. #initialize the product that corresponds to the row
        for indCol, valCol in enumerate(row):

        if valCol > 0:
            tempVal = funPos(valCol, myList[indCol])

        elif valCol < 0:
            tempVal = funNeg(valCol, myList[indCol])

        elif valCol == 0:
            tempVal = 1

        tempProd = tempProd*tempVal

        output[indRow] = tempProd

    return output 

def funPos(val1,val2):
    return val1*val2

def funNeg(val1,val2):
    return val1*(val2+1)

myArray = np.array([[ 0.        ,  0.        ,  0.        ],
       [ 0.32, -6.79,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  1.5        ,  0.        ],
       [ 0.        ,  0.        , -1.71]])     

l = [1.1, 2., 3.4]

op = doCalcOnArray(myArray,l)
print op

输出是

[ 1.      -7.17024  1.       3.      -7.524  ]

这是想要的。
我的问题是是否有更有效的方法来做到这一点,因为这对于大型阵列来说相当“昂贵”。

编辑: 我接受了 gabhijit 的回答,因为他提出的纯 numpy 解决方案似乎是我正在处理的数组中最快的解决方案。请注意,RaJa 也有一个很好的工作解决方案,它需要 panda,而且 dave 的解决方案工作正常,可以作为一个很好的例子来说明如何使用生成器和 numpy 的“apply_along_axis”。

最佳答案

这是我尝试过的方法 - 使用 reduce、map。我不确定这有多快 - 但这是您想要做的吗?

编辑 4:最简单且最易读 - 使 l 成为一个 numpy 数组,然后大大简化 where

import numpy as np
import time

l = np.array([1.0, 2.0, 3.0])

def posFunc(x,y):
    return x*y

def negFunc(x,y):
    return x*(y+1)

def myFunc(x, y):
    if x > 0:
        return posFunc(x, y)
    if x < 0:
        return negFunc(x, y)
    else:
        return 1.0

myArray = np.array([
        [ 0.,0.,0.],
        [ 0.32, -6.79,  0.],
        [ 0.,0.,0.],
        [ 0.,1.5,0.],
        [ 0.,0., -1.71]])

t1 = time.time()
a = np.array([reduce(lambda x, (y,z): x*myFunc(z,l[y]), enumerate(x), 1) for x in myArray])
t2 = time.time()
print (t2-t1)*1000000
print a

基本上让我们只看最后一行,它说在 enumerate(xx) 中累积乘法,从 1(reduce 的最后一个参数)开始。 myFunc 只是获取 myArray(row) 中的元素和 l 中的元素@index row 并根据需要将它们相乘。

我的输出与你的不一样 - 所以我不确定这是否正是你想要的,但也许你可以遵循逻辑。

此外,我不太确定这对于大型阵列来说有多快。

编辑:以下是执行此操作的“纯 numpy 方式”。

my = myArray # just for brevity

t1 = time.time() 
# First set the positive and negative values
# complicated - [my.itemset((x,y), posFunc(my.item(x,y), l[y])) for (x,y) in zip(*np.where(my > 0))]
# changed to 
my = np.where(my > 0, my*l, my)
# complicated - [my.itemset((x,y), negFunc(my.item(x,y), l[y])) for (x,y) in zip(*np.where(my < 0))]
# changed to 
my = np.where(my < 0, my*(l+1), my)
# print my - commented out to time it.

# Now set the zeroes to 1.0s
my = np.where(my == 0.0, 1.0, my)
# print my  - commented out to time it

a = np.prod(my, axis=1)
t2 = time.time()
print (t2-t1)*1000000

print a

让我尽量解释 zip(*np.where(my != 0)) 部分。 np.where 简单地返回两个 numpy 数组,第一个数组是行的索引,第二个数组是与条件匹配的列的索引 (my != 0) 在这种情况下.我们采用这些索引的元组,然后使用 array.itemsetarray.item,幸运的是,列索引对我们来说是免费的,所以我们可以只采用元素@ 列表 l 中的那个索引。这应该比以前更快(并且可读性提高了几个数量级!!)。需要 timeit 来确定它是否确实是。

编辑 2:不必单独调用正负可以通过一次调用完成 np.where(my != 0)

关于python - 将多个函数应用于数组的每一行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/30507009/

相关文章:

javascript - 如何根据多变量对数据进行分类

jquery - 使用 JQuery 缩放锯齿状图像

python - 在 Python 中,字符串序列到底是什么? (或者 Glib 错误?)

python - 部署 Django 应用程序时出现问题

PHP 数组多重排序

javascript - 使用多维数组创建表 - 错误 : Cannot set property of undefined

java - 为什么反射慢?

javascript - 使用匿名函数会影响性能吗?

python - 如何使用 pandas 数据帧的正则表达式仅提取一个捕获组?

python - 在 TfidfVectorizer 中删除法语和英语中的停用词