python - 提高 numpy 映射操作的性能

标签 python performance numpy

我有一个大小为 (4, X, Y) 的 numpy 数组,其中第一个维度代表一个 (R,G,B,A) 四元组。 我的目标是将每个 X*Y RGBA 四元组转置为 X*Y 浮点值,给定一个匹配它们的字典。

我目前的代码如下:

codeTable = {
    (255, 255, 255, 127): 5.5,
    (128, 128, 128, 255): 6.5,
    (0  , 0  , 0  , 0  ): 7.5,
}

for i in range(0, rows):
    for j in range(0, cols):
        new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)

其中 data 是一个大小为 (4, rows, cols) 的 numpy 数组,而 new_data 是大小为 (rows , 列).

代码运行良好,但需要相当长的时间。我应该如何优化那段代码?

这是一个完整的例子:

import numpy

codeTable = {
    (253, 254, 255, 127): 5.5,
    (128, 129, 130, 255): 6.5,
    (0  , 0  , 0  , 0  ): 7.5,
}

# test data
rows = 2
cols = 2
data = numpy.array([
    [[253, 0], [128,   0], [128,  0]],
    [[254, 0], [129, 144], [129,  0]],
    [[255, 0], [130, 243], [130,  5]],
    [[127, 0], [255, 120], [255,  5]],
])

new_data = numpy.zeros((rows,cols), numpy.float32)

for i in range(0, rows):
    for j in range(0, cols):
        new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)

# expected result for `new_data`:
# array([[  5.50000000e+00,   7.50000000e+00],
#        [  6.50000000e+00,  -9.99900000e+03],
#        [  6.50000000e+00,  -9.99900000e+03], dtype=float32)

最佳答案

这是一种返回预期结果的方法,但由于数据量如此之小,很难知道这种方法对您来说是否更快。但是,由于我避免了双重 for 循环,我想您会看到相当不错的加速。

import numpy
import pandas as pd


codeTable = {
    (253, 254, 255, 127): 5.5,
    (128, 129, 130, 255): 6.5,
    (0  , 0  , 0  , 0  ): 7.5,
}

# test data
rows = 3
cols = 2
data = numpy.array([
    [[253, 0], [128,   0], [128,  0]],
    [[254, 0], [129, 144], [129,  0]],
    [[255, 0], [130, 243], [130,  5]],
    [[127, 0], [255, 120], [255,  5]],
])

new_data = numpy.zeros((rows,cols), numpy.float32)

for i in range(0, rows):
    for j in range(0, cols):
        new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)

def create_output(data):
    # Reshape your two data sources to be a bit more sane
    reshaped_data = data.reshape((4, -1))
    df = pd.DataFrame(reshaped_data).T

    reshaped_codeTable = []
    for key in codeTable.keys():
        reshaped = list(key) + [codeTable[key]]
        reshaped_codeTable.append(reshaped)
    ct = pd.DataFrame(reshaped_codeTable)

    # Merge on the data, replace missing merges with -9999
    result = df.merge(ct, how='left')
    newest_data = result[4].fillna(-9999)

    # Reshape
    output = newest_data.reshape(rows, cols)
    return output

output = create_output(data)
print(output)
# array([[  5.50000000e+00,   7.50000000e+00],
#        [  6.50000000e+00,  -9.99900000e+03],
#        [  6.50000000e+00,  -9.99900000e+03])

print(numpy.array_equal(new_data, output))
# True

关于python - 提高 numpy 映射操作的性能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37630661/

相关文章:

Python:非凸网格的边界点

python - 用数字替换字符串 numpy 数组

python - 如何列出 ipython session 中加载的所有名称?

c++ - 我应该使用成员变量还是在函数内部声明变量?

python - Docker+Gunicorn+Flask,我不明白为什么我的设置不起作用

sql - 提高PostgreSQL查询效率-一对多,Count为1

Java 等于原始速度与对象速度

python - 检查当前行中的所有列值是否小于 Pandas 数据框中的前一行

python - 无法访问 Django 模板中的 UserProfile 模型字段。试过 {{ user.userprofile }}

python - 将多项式 w(z) 转换为 w((1-z)/(1+z))