我有一个大小为 (4, X, Y) 的 numpy 数组,其中第一个维度代表一个 (R,G,B,A) 四元组。
我的目标是将每个 X*Y
RGBA 四元组转置为 X*Y
浮点值,给定一个匹配它们的字典。
我目前的代码如下:
codeTable = {
(255, 255, 255, 127): 5.5,
(128, 128, 128, 255): 6.5,
(0 , 0 , 0 , 0 ): 7.5,
}
for i in range(0, rows):
for j in range(0, cols):
new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)
其中 data
是一个大小为 (4, rows, cols)
的 numpy 数组,而 new_data
是大小为 (rows , 列)
.
代码运行良好,但需要相当长的时间。我应该如何优化那段代码?
这是一个完整的例子:
import numpy
codeTable = {
(253, 254, 255, 127): 5.5,
(128, 129, 130, 255): 6.5,
(0 , 0 , 0 , 0 ): 7.5,
}
# test data
rows = 2
cols = 2
data = numpy.array([
[[253, 0], [128, 0], [128, 0]],
[[254, 0], [129, 144], [129, 0]],
[[255, 0], [130, 243], [130, 5]],
[[127, 0], [255, 120], [255, 5]],
])
new_data = numpy.zeros((rows,cols), numpy.float32)
for i in range(0, rows):
for j in range(0, cols):
new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)
# expected result for `new_data`:
# array([[ 5.50000000e+00, 7.50000000e+00],
# [ 6.50000000e+00, -9.99900000e+03],
# [ 6.50000000e+00, -9.99900000e+03], dtype=float32)
最佳答案
这是一种返回预期结果的方法,但由于数据量如此之小,很难知道这种方法对您来说是否更快。但是,由于我避免了双重 for 循环,我想您会看到相当不错的加速。
import numpy
import pandas as pd
codeTable = {
(253, 254, 255, 127): 5.5,
(128, 129, 130, 255): 6.5,
(0 , 0 , 0 , 0 ): 7.5,
}
# test data
rows = 3
cols = 2
data = numpy.array([
[[253, 0], [128, 0], [128, 0]],
[[254, 0], [129, 144], [129, 0]],
[[255, 0], [130, 243], [130, 5]],
[[127, 0], [255, 120], [255, 5]],
])
new_data = numpy.zeros((rows,cols), numpy.float32)
for i in range(0, rows):
for j in range(0, cols):
new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)
def create_output(data):
# Reshape your two data sources to be a bit more sane
reshaped_data = data.reshape((4, -1))
df = pd.DataFrame(reshaped_data).T
reshaped_codeTable = []
for key in codeTable.keys():
reshaped = list(key) + [codeTable[key]]
reshaped_codeTable.append(reshaped)
ct = pd.DataFrame(reshaped_codeTable)
# Merge on the data, replace missing merges with -9999
result = df.merge(ct, how='left')
newest_data = result[4].fillna(-9999)
# Reshape
output = newest_data.reshape(rows, cols)
return output
output = create_output(data)
print(output)
# array([[ 5.50000000e+00, 7.50000000e+00],
# [ 6.50000000e+00, -9.99900000e+03],
# [ 6.50000000e+00, -9.99900000e+03])
print(numpy.array_equal(new_data, output))
# True
关于python - 提高 numpy 映射操作的性能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37630661/