我有一大堆值打包在一个 4D numpy 数组中(x、y、z 中的数千个值数千次)。对于这些值中的每一个,我都需要来自 matplotlib.cm.ScalarMappable
对象的“颜色向量”(RGBA)。
我发现遍历这样一个数组变得相当慢,我想知道是否有一种方法可以通过采用不同的方法来显着加快速度。例如,是否可以将整个 numpy 数组(大于 2D)传递给 ScalarMappable
,以便以更 numpythonic 或矢量化的方式进行此操作?
我的 3D 案例代码示例:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import timeit
def get_colors_1(data,x,y,z):
colors = np.zeros( (x,y,z,4), dtype=np.float16)
for i in range(x):
for j in range(y):
for k in range(z):
colors[i,j,k,:] = m.to_rgba(data[i,j,k])
return colors
def get_colors_2(data,x,y,z):
colors = np.array([[[m.to_rgba(data[i,j,k]) for k in range(z)] for j in range(y)] for i in range(x)], dtype=np.float16)
return colors
def get_colors_3(data,x,y,z):
colors = np.zeros((x,y,z,4), dtype=np.float16)
for i in range(x):
colors[i,:,:,:] = m.to_rgba(data[i,:,:])
return colors
x, y, z = 30, 20, 10
data = np.random.rand(x,y,z)
cmap = matplotlib.cm.get_cmap('jet')
norm = matplotlib.colors.PowerNorm(vmin=0.0, vmax=1.0, gamma=2.5)
m = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
start_time = timeit.default_timer()
colors = get_colors_1(data,x,y,z)
elapsed = timeit.default_timer() - start_time
print('time elapsed: '+str(elapsed))
start_time = timeit.default_timer()
colors = get_colors_2(data,x,y,z)
elapsed = timeit.default_timer() - start_time
print('time elapsed: '+str(elapsed))
start_time = timeit.default_timer()
colors = get_colors_3(data,x,y,z)
elapsed = timeit.default_timer() - start_time
print('time elapsed: '+str(elapsed))
第三种方法(一次传递 2D 数组)显示出很大的性能提升,但我想知道是否可以将其进一步推进。
time elapsed: 0.5877857000014046
time elapsed: 0.5911024999986694
time elapsed: 0.004590500000631437
最佳答案
查看help ScalarMappable.to_rgba
,
def to_rgba(self, x, alpha=None, bytes=False, norm=True):
In the normal case, x is a 1D or 2D sequence of scalars, and the corresponding ndarray of rgba values will be returned, based on the norm and colormap set for this ScalarMappable. There is one special case, for handling images that are already rgb or rgba, such as might have been read from an image file. If x is an ndarray with 3 dimensions ...
是
sm = ScalarMappable( norm=None, cmap=cmap )
flat = data.reshape( -1 ) # a view
rgba = sm.to_rgba( flat, bytes=True, norm=False ).reshape( data.shape + (4,) )
快多了?
(来源:ScalarMappable to_rgba .)
关于python:从 ScalarMappable 获取整个 numpy 数组的颜色,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60058559/