我在使用 NumPy 时遇到以下问题:
代码:
import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)
输出:
['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!
我可以看到 NumPy 以某种方式从函数返回的第一个值推断出数据类型。我想出了以下解决方法 - 从具有明确声明的 dtype 而不是字符串的函数返回 NumPy 数组,并 reshape 结果:
def get_label_2(x):
if x.sum() <= 10:
return np.array(['SMALL'], dtype='|S5')
else:
return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])
你知道这个问题更优雅的解决方案吗?
最佳答案
你可以使用np.where
:
arr1 = np.array([[1, 2], [30, 40]])
arr2 = np.array([[30, 40], [1, 2]])
print(np.where(arr1.sum(axis=1)<=10,'SMALL','BIG'))
print(np.where(arr2.sum(axis=1)<=10,'SMALL','BIG'))
['SMALL' 'BIG']
['BIG' 'SMALL']
在函数中:
def get_label(x, threshold, axis=1, label1='SMALL', label2='BIG'):
return np.where(x.sum(axis=axis) <= threshold, label1, label2)
关于python - Numpy apply_along_axis 推断出错误的数据类型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46096748/