python - 在给定转移矩阵的情况下有效地将转移应用于状态矩阵

标签 python numpy transition numba

我希望将状态更改应用于具有 k 个类别的大型分类矩阵 (M),其中我知道每个类别到 k (T) 中每个其他类别的转移概率

本质上,我希望能够有效地获取 M 中的每个元素,在给定 T 中的概率的情况下模拟状态变化,并用计算出的变化替换元素。

我尝试了一些解决方案:

  • 暴力嵌套 for 带索引的循环(太长了)
  • numba 辅助嵌套 for 循环(~500 毫秒,这对我来说太长了)
  • 为每个类别和替换预先计算的绘制(~400 毫秒)
import numpy as np


def categorical_transition(mat, t_mat, k=4):

    transformed_mat = mat.copy()
    cat_counts = np.bincount(mat.reshape(-1,))

    for i in range(k):
        rand_vec = np.random.multinomial(1, t_mat[i], cat_counts[i])

        choice = np.where(rand_vec)[1]

        transformed_mat[mat == i] = choice

    return transformed_mat


# load data
mat = np.random.choice(4, (16000, 256))
t_mat = np.random.random((4, 4))

# normalize transition matrix
for i in range(t_mat.shape[0]):
    t_mat[i] = t_mat[i] / t_mat[i].sum()

transformed_mat = categorical_transition(mat, t_mat)

此方法有效,但速度较慢,如有任何关于更有效的实现方法的建议,我将不胜感激

最佳答案

始终提供您到目前为止尝试过的所有实现

我尝试了一个简单的实现,如 here. 所述它应该比您的解决方案快 20-80 倍左右,具体取决于您有多少核心可用。

实现

@nb.njit(parallel=True)  
def categorical_transition_nb(mat_in, t_mat):
    mat=np.reshape(mat_in,-1)
    transformed_mat = np.empty_like(mat)
    for i in nb.prange(mat.shape[0]):
        rand_number=np.random.rand()
        probabilities=t_mat[mat[i],:]
        if rand_number<probabilities[0]:
            transformed_mat[i]=0
        else:
            for j in range(1,probabilities.shape[0]):
                if rand_number>=probabilities[j-1] and rand_number<probabilities[j]:
                    transformed_mat[i]=j

    return transformed_mat.reshape(mat_in.shape)

时间

import numpy as np
import numba as nb

# load data
mat = np.random.choice(4, (16_000,256))
t_mat = np.random.random((4, 4))

# normalize transition matrix
for i in range(t_mat.shape[0]):
    t_mat[i] = t_mat[i] / t_mat[i].sum()

t_mat_2=np.cumsum(t_mat,axis=1)
%timeit transformed_mat_2 = categorical_transition_nb(mat, t_mat_2)
21.7 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

关于python - 在给定转移矩阵的情况下有效地将转移应用于状态矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58666001/

相关文章:

python - 如何使用 symPy 和 numPy 将符号替换为矩阵

python - 分割成 block 后如何合并图像

javascript - 将类 ="transition"添加到网站上所有链接的脚本

Python csv 模块与 pandas.read_csv 和 Python xlrd 与 pandas.read_excel

递归函数中的Python返回命令

python - 使用 Pool 在 Python 中进行多重处理

html - 过渡高度不起作用

python - 在Python中访问元组迭代列表中特定索引的正确方法

python - numpy - 向量化函数 : apply_over_axes/apply_along_axis

javascript - d3 在转换中翻译旋转顺序