我是 cython 的新手,正在尝试将 python 类转换为 cython。我不知道我应该如何在 instance Da 中定义参数 z
,以它可以同时处理 numpy.array 或仅处理的方式单个 float
数字。
cdef class Cosmology(object):
cdef double omega_m, omega_lam, omega_c
def __init__(self,double omega_m=0.3,double omega_lam=0.7):
self.omega_m = omega_m
self.omega_lam = omega_lam
self.omega_c = (1. - omega_m - omega_lam)
cpdef double a(self, double z):
cdef double a
return 1./(1+z)
cpdef double E(self, double a):
cdef double E
return (self.omega_m*a**(-3) + self.omega_c*a**(-2) + self.omega_lam)**0.5
cpdef double __angKernel(self, double x):
cdef __angKernel:
"""Integration kernel"""
return self.E(x**-1)**-1
cpdef double Da(self, double z, double z_ref=0):
cdef double Da
if isinstance(z, np.ndarray):
da = np.zeros_like(z)
for i in range(len(da)):
da[i] = self.Da(z[i], z_ref)
return da
else:
if z < 0:
raise ValueError("Redshift z must not be negative")
if z < z_ref:
raise ValueError("Redshift z must not be smaller than the reference redshift")
d = integrate.quad(self.__angKernel, z_ref+1, z+1,epsrel=1.e-6, epsabs=1.e-12)
rk = (abs(self.omega_c))**0.5
if (rk*d[0] > 0.01):
if self.omega_c > 0:
d[0] = sinh(rk*d[0])/rk
if self.omega_c < 0:
d[0] = sin(rk*d[0])/rk
return d[0]/(1+z)
我也想知道我是否将所有参数正确转换为cython参数?我想改变我原来的python代码来提高计算速度。我认为我的代码中的瓶颈之一应该是 integrate.quad
。 cython
中是否有此函数的任何替代项有助于加快我的代码的性能?
cdef class halo_positions(object):
cdef double x = None
cdef double y = None
def __init__(self,numpy.ndarray[double, ndim=1] positions):
self.x = positions[0]
self.y = positions[1]
如果我想将数组传递给 halo_positions
实例,这是正确的方法吗?
最佳答案
如果你的类被定义为 cdef
它将只能在 Cython 中访问(而不是在 Python 中)使得使用 cpdef
和 def 是不必要的并且效率不高
用于类方法。您可以将它们全部转换为 cdef
。
当您告诉 z
是 double
时,它将只接受 double
。如果你希望这个参数是两种不同的类型,你应该保持它的类型未声明,但是当 z
是一个 ndarray
时,这将直接影响循环性能。
或者,您可以使用 double *
并传递它的大小,当大小为 1
时,它是一个 double,当大小为 >1
一个数组。该功能将是:
cdef double Da(self, int size, double *z, double z_ref=0):
if size>1:
da = np.zeros(size)
for i in range(size):
da[i] = self.Da(1, &z[i], z_ref)
return da
else:
if z[0] < 0:
raise ValueError("Redshift z must not be negative")
if z[0] < z_ref:
raise ValueError("Redshift z must not be smaller than the reference redshift")
d = integrate.quad(self.__angKernel, z_ref+1, z[0]+1,
epsrel=1.e-6, epsabs=1.e-12)
rk = (abs(self.omega_c))**0.5
if (rk*d[0] > 0.01):
if self.omega_c > 0:
d[0] = sinh(rk*d[0])/rk
if self.omega_c < 0:
d[0] = sin(rk*d[0])/rk
return d[0]/(1+z[0])
关于python - 定义cython类内部函数的参数和cython中的快速积分计算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24085711/