python - 如何引导 numpy 数组的最里面的数组?

标签 python arrays numpy numpy-ndarray statistics-bootstrap

我有一个这些维度的 numpy 数组

data.shape(类别、模型、类型、事件):(10, 11, 50, 100)

现在我只想在最里面的数组(100)中进行替换样本。对于像这样的单个数组:

数据[0][0][0]

数组([ 40.448624 , 39.459843 , 33.76762 , 38.944622 , 21.407362 , 35.55499、68.5111、16.512974、21.118315、18.447166、 16.026619、21.596252、41.798622、63.01645、46.886642、 68.874756、17.472408、53.015724、85.41213、59.388977、 17.352108、61.161705、23.430847、20.203123、22.73194、 77.40547、43.02974、29.745787、21.50163、13.820962、 46.91466、41.43656、18.008326、13.122162、59.79936、 94.555305、24.798452、30.362497、13.629236、10.792178、 35.298515、20.904285、15.409604、20.567234、46.376335、 13.82727、17.970661、18.408686、21.987917、21.30094、 24.26776、27.399046、49.16879、21.831453、66.577、 15.524615、18.091696、24.346598、24.709772、19.068447、 24.221592、25.244864、52.865868、22.860783、23.586731、 18.928782、21.960285、74.77856、15.176119、20.795431、 14.3638935、35.937237、29.993324、30.848495、48.145336、 38.02541、101.15249、49.801117、38.123184、12.041505、 18.788296、20.53382、31.20367、19.76104、92.56279、 41.62944、23.53344、18.967432、14.781404、20.02018、 27.736559、16.108913、44.935062、12.629299、34.65672、 20.60169、21.779675、31.585844、23.768578、92.463196]、 数据类型=float32)

我可以使用以下内容进行替换样本:np.random.choice(data[0][0][0], 100),我将做了数千次。

数组([ 13.629236, 92.56279 , 21.960285, 20.567234, 21.50163 , 16.026619、20.203123、23.430847、16.512974、15.524615、 18.967432、22.860783、85.41213、21.779675、23.586731、 24.26776、66.577、20.904285、19.068447、21.960285、 68.874756、31.585844、23.586731、61.161705、101.15249、 59.79936、16.512974、43.02974、16.108913、24.26776、 23.430847、14.781404、40.448624、13.629236、24.26776、 19.068447、16.026619、16.512974、16.108913、77.40547、 12.629299、31.585844、24.798452、18.967432、14.781404、 23.430847、49.16879、18.408686、22.73194、10.792178、 16.108913、18.967432、12.041505、85.41213、41.62944、 31.20367、17.970661、29.745787、39.459843、10.792178、 43.02974、21.831453、21.50163、24.798452、30.362497、 21.50163、18.788296、20.904285、17.352108、41.798622、 18.447166、16.108913、19.068447、61.161705、52.865868、 20.795431、85.41213、49.801117、13.82727、18.928782、 41.43656、46.886642、92.56279、41.62944、18.091696、 20.60169、48.145336、20.53382、40.448624、20.60169、 23.586731、22.73194、92.56279、94.555305、22.73194、 17.352108, 46.886642, 27.399046, 18.008326, 15.176119], 数据类型=float32)

但是由于 np.random.choice 中没有 axis ,我该如何对所有数组(即(类别、模型、类型))执行此操作?或者循环遍历它是唯一的选择吗?

最佳答案

最快/最简单的答案是基于对数组的扁平版本进行索引:

def resampFlat(arr, reps):
    n = arr.shape[-1]

    # create an array to shift random indexes as needed
    shift = np.repeat(np.arange(0, arr.size, n), n).reshape(arr.shape)

    # get a flat view of the array
    arrflat = arr.ravel()
    # sample the array by generating random ints and shifting them appropriately
    return np.array([arrflat[np.random.randint(0, n, arr.shape) + shift] 
                     for i in range(reps)])

时间确认这是最快的答案。

时间

我测试了上面的 resampFlat 函数以及更简单的基于 for 循环的解决方案:

def resampFor(arr, reps):
    # store the shape for the return value
    shape = arr.shape
    # flatten all dimensions of arr except the last
    arr = arr.reshape(-1, arr.shape[-1])
    # preallocate the return value
    ret = np.empty((reps, *arr.shape), dtype=arr.dtype)
    # generate the indices of the resampled values
    idxs = np.random.randint(0, arr.shape[-1], (reps, *arr.shape))

    for rep,idx in zip(ret, idxs):
        # iterate over the resampled replicates
        for row,rowrep,i in zip(arr, rep, idx):
            # iterate over the event arrays within a replicate
            rowrep[...] = row[i]

    # give the return value the appropriate shape
    return ret.reshape((reps, *shape))

以及基于 Paul Panzer 奇特索引方法的解决方案:

def resampFancyIdx(arr, reps):
    idx = np.random.randint(0, arr.shape[-1], (reps, *data.shape))
    _, I, J, K, _ = np.ogrid[tuple(map(slice, (0, *arr.shape[:-1], 0)))]

    return arr[I, J, K, idx]

我使用以下数据进行了测试:

shape = ((10, 11, 50, 100))
data = np.arange(np.prod(shape)).reshape(shape)

以下是数组展平方法的结果:

%%timeit
resampFlat(data, 100)

1.25 s ± 9.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

for 循环方法的结果:

%%timeit
resampFor(data, 100)

1.66 s ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

来自 Paul 的精美索引:

%%timeit
resampFancyIdx(data, 100)

1.42 s ± 16.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

与我的预期相反,resampFancyIdx 击败了 resampFor,实际上我必须相当努力才能想出更好的东西。在这一点上,我真的很想更好地解释一下 C 级索引如何工作,以及为什么它如此高效。

关于python - 如何引导 numpy 数组的最里面的数组?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53236040/

相关文章:

python - python 将每个句子的第一个字母大写

javascript - 如何在数组中正确使用map()?

javascript - Angular2 从对象数组中获取值

python - 如何使用未知长度的列表对 np.array 进行切片?

python - 从多维 numpy 数组中选择随机数组并进行替换

python - 编译 Jinja2 AST 的小节

python - 在 xonsh 中,我如何从管道接收到 python 表达式?

python serval变量组合成一个字典?

arrays - 从 Mongoose 数组中删除元素

python - 如何在 python pandas 的同一列上进行分组并计算唯一值和某些值的计数作为聚合?