python - 使用 apply_along_axis 绘制

标签 python numpy matplotlib scipy

我有一个 3D ndarry 对象,它包含光谱数据(即空间 xy 维度和能量维度)。我想提取并绘制线图中每个像素的光谱。目前,我正在沿我感兴趣的轴使用 np.ndenumerate 来执行此操作,但速度很慢。我希望尝试 np.apply_along_axis,看看它是否更快,但我不断收到一个奇怪的错误。

什么有效:

# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt

ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume

# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel) 
# on a single line plot:

for (x,y) in np.ndenumerate(ar[:,:,1]):
    plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')

据我了解,这基本上是一个循环,其效率低于基于数组的方法,因此我想尝试一种使用 np.apply_along_axis 的方法,看看它是否更快。然而,这是我第一次尝试使用 Python,我仍在寻找它的工作原理,所以如果这个想法存在根本性缺陷,请纠正我!

我想尝试的:

# define a function to pass to apply_along_axis
def pa(y,x):
    if ~all(np.isnan(y)): # only do the plot if there is actually data there...
        plt.plot(x,y,alpha=0.15,color='black')
    return

# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)

# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!

产生的错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
     12 # pa(ar[1,1,:],ax)
     13 
---> 14 np.apply_along_axis(pa,2,ar,ax)

//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
    101         holdshape = outshape
    102         outshape = list(arr.shape)
--> 103         outshape[axis] = len(res)
    104         outarr = zeros(outshape, asarray(res).dtype)
    105         outarr[tuple(i.tolist())] = res

TypeError: object of type 'NoneType' has no len()

任何关于这里出了什么问题的想法/关于如何更好地做到这一点的建议都会很棒。

谢谢!

最佳答案

apply_along_axis 根据函数的输出创建一个新数组

您将返回 None(不返回任何内容)。因此错误。 Numpy 检查返回输出的长度以查看它对新数组是否有意义。

因为您不是根据结果构造新数组,所以没有理由使用 apply_along_axis。它不会更快。

但是,您当前的 ndenumerate 语句完全等同于:

import numpy as np
import matplotlib.pyplot as plt

ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')

一般来说,你可能想做这样的事情:

for pixel in ar.reshape(-1, ar.shape[-1]):
    plt.plot(x_values, pixel, ...)

这样您就可以轻松地迭代高光谱阵列中每个像素的光谱。


这里的瓶颈可能不是您使用数组的方式。在 matplotlib 中使用相同的参数单独绘制每条线会有些低效。

构造时间会稍长一些,但是 LineCollection 的渲染速度会快得多。 (基本上,使用 LineCollection 告诉 matplotlib 不要费心检查每条线的属性,只需将它们全部传递给低级渲染器以相同的方式绘制。你绕过了一堆单个 draw 调用支持大型对象的单个 draw。)

不利的一面是,代码的可读性会差一些。

稍后我将添加一个示例。

关于python - 使用 apply_along_axis 绘制,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22431985/

相关文章:

python - 行未添加到数据库中

php - 从 php 发布到 Flask

python - 高效的多精度数值数组

python - 应用 pandas df 中所有列的分布

python - 如何在 jupyter 笔记本中使用 python 向图像添加视觉注释?

python - Matplotlib 及其与 tkinter 的连接

python - 在 Django 中添加来自同一模型的两个外键字段

python - Django 与南方 : Adding new field but DatabaseError occurs "table already exists"

python - 获取 numpy ndarray 的一部分(对于任意维度)

python - 将结果写入 Excel