我有一个 3D ndarry 对象,它包含光谱数据(即空间 xy 维度和能量维度)。我想提取并绘制线图中每个像素的光谱。目前,我正在沿我感兴趣的轴使用 np.ndenumerate 来执行此操作,但速度很慢。我希望尝试 np.apply_along_axis
,看看它是否更快,但我不断收到一个奇怪的错误。
什么有效:
# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt
ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume
# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel)
# on a single line plot:
for (x,y) in np.ndenumerate(ar[:,:,1]):
plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')
据我了解,这基本上是一个循环,其效率低于基于数组的方法,因此我想尝试一种使用 np.apply_along_axis
的方法,看看它是否更快。然而,这是我第一次尝试使用 Python,我仍在寻找它的工作原理,所以如果这个想法存在根本性缺陷,请纠正我!
我想尝试的:
# define a function to pass to apply_along_axis
def pa(y,x):
if ~all(np.isnan(y)): # only do the plot if there is actually data there...
plt.plot(x,y,alpha=0.15,color='black')
return
# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)
# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!
产生的错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
12 # pa(ar[1,1,:],ax)
13
---> 14 np.apply_along_axis(pa,2,ar,ax)
//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
101 holdshape = outshape
102 outshape = list(arr.shape)
--> 103 outshape[axis] = len(res)
104 outarr = zeros(outshape, asarray(res).dtype)
105 outarr[tuple(i.tolist())] = res
TypeError: object of type 'NoneType' has no len()
任何关于这里出了什么问题的想法/关于如何更好地做到这一点的建议都会很棒。
谢谢!
最佳答案
apply_along_axis
根据函数的输出创建一个新数组。
您将返回 None
(不返回任何内容)。因此错误。 Numpy 检查返回输出的长度以查看它对新数组是否有意义。
因为您不是根据结果构造新数组,所以没有理由使用 apply_along_axis
。它不会更快。
但是,您当前的 ndenumerate
语句完全等同于:
import numpy as np
import matplotlib.pyplot as plt
ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')
一般来说,你可能想做这样的事情:
for pixel in ar.reshape(-1, ar.shape[-1]):
plt.plot(x_values, pixel, ...)
这样您就可以轻松地迭代高光谱阵列中每个像素的光谱。
这里的瓶颈可能不是您使用数组的方式。在 matplotlib
中使用相同的参数单独绘制每条线会有些低效。
构造时间会稍长一些,但是 LineCollection
的渲染速度会快得多。 (基本上,使用 LineCollection
告诉 matplotlib 不要费心检查每条线的属性,只需将它们全部传递给低级渲染器以相同的方式绘制。你绕过了一堆单个 draw
调用支持大型对象的单个 draw
。)
不利的一面是,代码的可读性会差一些。
稍后我将添加一个示例。
关于python - 使用 apply_along_axis 绘制,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22431985/