python - Python 和 numpy 中两个变量循环的列表理解

标签 python list numpy

我必须从函数返回的值x、y创建一个2D numpy数组,以使用matplotlib中的contourf进行绘图,到目前为止我正在使用类似“C”的结构,它似乎是 Python 的效率非常低:

    dim_x = np.linspace(self.min_x, self.max_x, self.step)
    dim_y = np.linspace(self.min_y, self.max_y, self.step)
    X, Y = np.meshgrid(dim_x, dim_y)

    len_x = len(dim_x)
    len_y = len(dim_y)


    a = np.zeros([len_x, len_y], dtype=complex)

    for i, y in enumerate(dim_y):
        for j, x in enumerate(dim_x):
            a[i][j] = aux_functions.final_potential(complex(x, y), element_list)

cs = plt.contourf(X, Y, (a.real), 100)

如何以更 Pythonic 的方式完成此操作?

谢谢!

最佳答案

如果您可以将 final_pottial 重写为矢量化函数,那就太理想了。一个简单且可能过于明显的示例:

>>> dim_x = np.linspace(0, 2, 5)
>>> dim_y = np.linspace(0, 2, 5)
>>> X * Y
array([[ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.25,  0.5 ,  0.75,  1.  ],
       [ 0.  ,  0.5 ,  1.  ,  1.5 ,  2.  ],
       [ 0.  ,  0.75,  1.5 ,  2.25,  3.  ],
       [ 0.  ,  1.  ,  2.  ,  3.  ,  4.  ]])

但是,如果您确实不能这样做,您也可以矢量化:

>>> np.vectorize(lambda x, y: x * y + 2)(X, Y)
array([[ 2.  ,  2.  ,  2.  ,  2.  ,  2.  ],
       [ 2.  ,  2.25,  2.5 ,  2.75,  3.  ],
       [ 2.  ,  2.5 ,  3.  ,  3.5 ,  4.  ],
       [ 2.  ,  2.75,  3.5 ,  4.25,  5.  ],
       [ 2.  ,  3.  ,  4.  ,  5.  ,  6.  ]])

就您而言,它可能看起来像这样:

def wrapper(x, y): 
    return aux_functions.final_potential(complex(x, y), element_list)

a = np.vectorize(wrapper)(X, Y)

这可能比嵌套的 for 循环快一点,尽管 python 函数调用的开销会降低 numpy 的效率。在我过去所做的测试中,使用 vectorize 提供了适度的 5 倍加速。 (相比之下,纯 numpy 运算的速度提高了 100 倍或 1000 倍,如 X * Y 示例中所示。)

关于python - Python 和 numpy 中两个变量循环的列表理解,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/11017347/

相关文章:

python - 使用 Flask 将 URL 加载到 iFrame

python - 在 Python 中访问未绑定(bind)到变量的对象

python - 带有 filter() 和 order_by() 的查询集 distinct() 不起作用

javascript - 按元素内容选择元素

django - 如何在 Django 中一次将多个对象添加到 ManyToMany 关系中?

python - 如何解决 + : 'int' and 'tuple' because of trying to return 2 values with lambda? 不受支持的操作数类型

python - 移除/删除 NxM 矩阵每行中的每个最小值?

python - 在Python中的数据类上使用prefect 2.0流程

python - 长的 NumPy 数组无法完全打印?

Python 3 - ValueError : Found array with 0 sample(s) (shape=(0, 11)) 而 MinMaxScaler 要求至少为 1