python - Numpy apply_along_axis 推断出错误的数据类型

标签 python numpy

我在使用 NumPy 时遇到以下问题:

代码:

import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)

输出:

['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!

我可以看到 NumPy 以某种方式从函数返回的第一个值推断出数据类型。我想出了以下解决方法 - 从具有明确声明的 dtype 而不是字符串的函数返回 NumPy 数组,并 reshape 结果:

def get_label_2(x):
    if x.sum() <= 10:
        return np.array(['SMALL'], dtype='|S5')
    else:
        return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])

你知道这个问题更优雅的解决方案吗?

最佳答案

你可以使用np.where:

arr1 = np.array([[1, 2], [30, 40]])
arr2 = np.array([[30, 40], [1, 2]])

print(np.where(arr1.sum(axis=1)<=10,'SMALL','BIG'))
print(np.where(arr2.sum(axis=1)<=10,'SMALL','BIG'))
['SMALL' 'BIG']
['BIG' 'SMALL']

在函数中:

def get_label(x, threshold, axis=1, label1='SMALL', label2='BIG'):
    return np.where(x.sum(axis=axis) <= threshold, label1, label2)

关于python - Numpy apply_along_axis 推断出错误的数据类型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46096748/

相关文章:

python - 将 Pandas 数据框中的整数二值化

python - 如何在 pandas Dataframe 中实现我自己的公式?

python - CompositeVideoClip 上的 Moviepy 返回错误 = TypeError : 'float' object cannot be interpreted as an integer

python - 聚类计算的有效距离

python - Fudge:@patch 当 from X import Y'ing 而不是 import X 时不起作用?

c# - 如何在 C# 中比较相邻的双端队列元素?

python - 在 Beautiful Soup 中查找并存储根的子代

python - 循环遍历 DF 列以删除包含西类牙语文本的行

python - numpy 中 [] 和 [[]] 的区别

python - numpy数组任意列之间的(内存)高效操作