python - Cython Gibbs 采样器比 numpy 采样器稍慢

标签 python numpy cython

我已经实现了一个 Gibbs 采样器来生成纹理图像。根据beta参数(shape(4)数组),我们可以生成各种纹理。

这是我使用 Numpy 的初始函数:

def gibbs_sampler(img_label, betas, burnin, nb_samples):
    nb_iter = burnin + nb_samples

    lst_samples = []

    labels = np.unique(img)

    M, N = img.shape
    img_flat = img.flatten()

    # build neighborhood array by means of numpy broadcasting:
    m, n = np.ogrid[0:M, 0:N]

    top_left, top, top_right =   m[0:-2, :]*N + n[:, 0:-2], m[0:-2, :]*N + n[:, 1:-1]  , m[0:-2, :]*N + n[:, 2:]
    left, pix, right = m[1:-1, :]*N + n[:, 0:-2],  m[1:-1, :]*N + n[:, 1:-1], m[1:-1, :]*N + n[:, 2:]
    bottom_left, bottom, bottom_right = m[2:, :]*N + n[:, 0:-2],  m[2:, :]*N + n[:, 1:-1], m[2:, :]*N + n[:, 2:]

    mat_neigh = np.dstack([pix, top, bottom, left, right, top_left, bottom_right, bottom_left, top_right])

    mat_neigh = mat_neigh.reshape((-1, 9))    
    ind = np.arange((M-2)*(N-2))  

    # loop over iterations
    for iteration in np.arange(nb_iter):

        np.random.shuffle(ind)

        # loop over pixels
        for i in ind:                  

            truc = map(functools.partial(lambda label, img_flat, mat_neigh : 1-np.equal(label, img_flat[mat_neigh[i, 1:]]).astype(np.uint), img_flat=img_flat, mat_neigh=mat_neigh), labels)
            # bidule is of shape (4, 2, labels.size)
            bidule = np.array(truc).T.reshape((-1, 2, labels.size))

            # theta is of shape (labels.size, 4) 
            theta = np.sum(bidule, axis=1).T
            # prior is thus an array of shape (labels.size)
            prior = np.exp(-np.dot(theta, betas))

            # sample from the posterior
            drawn_label = np.random.choice(labels, p=prior/np.sum(prior))

            img_flat[(i//(N-2) + 1)*N + i%(N-2) + 1] = drawn_label


        if iteration >= burnin:
            print('Iteration %i --> sample' % iteration)
            lst_samples.append(copy.copy(img_flat.reshape(M, N)))

        else:
            print('Iteration %i --> burnin' % iteration)

    return lst_samples

我们无法摆脱任何循环,因为它是一种迭代算法。因此,我尝试通过使用 Cython(使用静态类型)来加快速度:

from __future__ import division
import numpy as np
import copy
cimport numpy as np
import functools
cimport cython

INTTYPE = np.int
DOUBLETYPE = np.double

ctypedef np.int_t INTTYPE_t
ctypedef  np.double_t DOUBLETYPE_t

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)


def func_for_map(label, img_flat,  mat_neigh, i):

   return  (1-np.equal(label, img_flat[mat_neigh[i, 1:]])).astype(INTTYPE)


def gibbs_sampler(np.ndarray[INTTYPE_t, ndim=2] img_label, np.ndarray[DOUBLETYPE_t, ndim=1] betas, INTTYPE_t burnin=5, INTTYPE_t nb_samples=1):


    assert img_label.dtype == INTTYPE and betas.dtype== DOUBLETYPE

    cdef unsigned int nb_iter = burnin + nb_samples 

    lst_samples = list()

    cdef np.ndarray[INTTYPE_t, ndim=1] labels
    labels = np.unique(img_label)

    cdef unsigned int M, N
    M = img_label.shape[0]
    N = img_label.shape[1]

    cdef np.ndarray[INTTYPE_t, ndim=1] ind     
    ind = np.arange((M-2)*(N-2), dtype=INTTYPE)

    cdef np.ndarray[INTTYPE_t, ndim=1] img_flat
    img_flat = img_label.flatten()


    # build neighborhood array:
    cdef np.ndarray[INTTYPE_t, ndim=2] m
    cdef np.ndarray[INTTYPE_t, ndim=2] n


    m = (np.ogrid[0:M, 0:N][0]).astype(INTTYPE)
    n = (np.ogrid[0:M, 0:N][1]).astype(INTTYPE)



    cdef np.ndarray[INTTYPE_t, ndim=2] top_left, top, top_right, left, pix, right, bottom_left, bottom, bottom_right

    top_left, top, top_right =   m[0:-2, :]*N + n[:, 0:-2], m[0:-2, :]*N + n[:, 1:-1]  , m[0:-2, :]*N + n[:, 2:]
    left, pix, right = m[1:-1, :]*N + n[:, 0:-2],  m[1:-1, :]*N + n[:, 1:-1], m[1:-1, :]*N + n[:, 2:]
    bottom_left, bottom, bottom_right = m[2:, :]*N + n[:, 0:-2],  m[2:, :]*N + n[:, 1:-1], m[2:, :]*N + n[:, 2:]

    cdef np.ndarray[INTTYPE_t, ndim=3] mat_neigh_init
    mat_neigh_init = np.dstack([pix, top, bottom, left, right, top_left, bottom_right, bottom_left, top_right])

    cdef np.ndarray[INTTYPE_t, ndim=2] mat_neigh
    mat_neigh = mat_neigh_init.reshape((-1, 9))    

    cdef unsigned int i
    truc = list()
    cdef np.ndarray[INTTYPE_t, ndim=3] bidule
    cdef np.ndarray[INTTYPE_t, ndim=2] theta
    cdef np.ndarray[DOUBLETYPE_t, ndim=1] prior
    cdef unsigned int drawn_label, iteration       



    # loop over ICE iterations
    for iteration in np.arange(nb_iter):

        np.random.shuffle(ind) 

        # loop over pixels        
        for i in ind:            

            truc = map(functools.partial(func_for_map, img_flat=img_flat, mat_neigh=mat_neigh, i=i), labels)                        


            bidule = np.array(truc).T.reshape((-1, 2, labels.size)).astype(INTTYPE)            


            theta = np.sum(bidule, axis=1).T

            # ok so far

            prior = np.exp(-np.dot(theta, betas)).astype(DOUBLETYPE)
#            print('ok after prior') 
#            return 0
            # sample from the posterior
            drawn_label = np.random.choice(labels, p=prior/np.sum(prior))

            img_flat[(i//(N-2) + 1)*N + i%(N-2) + 1] = drawn_label


        if iteration >= burnin:
            print('Iteration %i --> sample' % iteration)
            lst_samples.append(copy.copy(img_flat.reshape(M, N)))

        else:
            print('Iteration %i --> burnin' % iteration)   



    return lst_samples

但是,我得到的计算时间几乎相同,numpy 版本比 Cython 版本略快。

因此,我正在尝试改进 Cython 代码。

编辑:

对于两个函数(Cython 和非 Cython): 我已经替换了:

truc = map(functools.partial(lambda label, img_flat, mat_neigh : 1-np.equal(label, img_flat[mat_neigh[i, 1:]]).astype(np.uint), img_flat=img_flat, mat_neigh=mat_neigh), labels)

通过广播:

truc = 1-np.equal(labels[:, None], img_flat[mat_neigh[i, 1:]][None, :])

所有 np.arangerange 计算,并且先验的计算现在通过 np.einsum 完成,正如 Divakar 所建议的。

这两个函数都比以前更快,但 Python 函数仍然比 Cython 函数快一些。

最佳答案

我已经运行了 Cython in annotated mode在您的来源上,并查看了结果。也就是说,将其保存在 q.pyx 中后,我运行了

cython -a q.pyx
firefox q.html

(当然,使用任何你想要的浏览器)。

代码颜色为深黄色,表明就Cython而言,代码远非静态类型。 AFAICT,这分为两类。

在某些情况下,您最好静态输入代码:

  1. for iteration in np.arange(nb_iter):for i in ind: 中,您需要为每次迭代支付大约 30 行 C 代码。参见 here如何在 Cython 中高效访问 numpy 数组。

  2. truc = map(functools.partial(func_for_map, img_flat=img_flat, mat_neigh=mat_neigh, i=i), labels) 中,您并没有真正从静态中获得任何好处打字。我建议您cdef 函数func_for_map,然后自己在循环中调用它。

在其他情况下,您正在调用 numpy 向量化函数,例如,theta = np.sum(bidule, axis=1).T, prior = np.exp(-np.dot(theta, betas)).astype(DOUBLETYPE) 等。在这些情况下,Cython 确实没有太大优势。

关于python - Cython Gibbs 采样器比 numpy 采样器稍慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40246152/

相关文章:

python - Cython调用lapack,报错: `Cannot take address of Python variable'

python - 静态 Sitemap.xml Django

python - 如果发生异常,则移除 JSON 文件

python - 使用协程在 python 中实现责任链模式

python - 从 Pandas 中删除非重复行

python - 在 Python 中基于地类网格求和土地面积网格

python - 如何在 MatPlotLib (NumPy) 中绘制多维数据的轮廓?

0 和 1 的 Python 随机数组

python - Cython 作为 Python 到 C 转换器的示例程序

python - 使用 setuptools 创建调用外部 C 库的 cython 包