python-3.x - 如何删除 3d 数组上包含 nan 元素的子数组并保留形状?

标签 python-3.x numpy multidimensional-array reshape numpy-ndarray

我有一个形状为(863, 923, 2)的稀疏数组,其中包含很多NAN:

[[[ 43.06010628 -11.01121568]
  [ 25.03068277  16.3949826 ]
  [-23.75853158 -10.95350074]
  ...
  [ 25.52110353   3.00428452]
  [ 32.66945663   9.76115107]
  [ 19.1341548    8.48547008]]

 [[ 19.08099208  11.27167832]
  [-29.4360534  -12.39131814]
  [ 11.24612069  14.38915742]
  ...
  [ 16.6897315   10.04601296]
  [ 30.09409518  17.09382562]
  [ -9.47312129  -9.57484782]]

 [[ 21.22006655  -5.01340343]
  [ 11.65512749   2.32398374]
  [-22.14668148 -11.05883399]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]

 ...

 [[ 32.32522443  -3.73563526]
  [ 30.88408144  -2.92184744]
  [ 37.44548043 -21.8209554 ]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]

 [[ 36.85471348  -7.86696711]
  [ 37.20204074  -6.32105844]
  [ 32.32522443  -3.73563526]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]

 [[ 34.21397091  -5.88930588]
  [ 35.88819735  -7.64992589]
  [ 35.48958094 -10.34708285]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]]

我想删除所有包含 nan 的子数组,同时保留数组的维数。我的理解是,数组的形状将更改为类似 (m, n, 2) 的形状,但在删除 NAN 后无法生成这样的数组。 这是我的尝试:

nonnanarr = arr[~np.isnan(arr).any(axis=-1)].reshape((863, -1, 2))

这是错误消息:

Traceback (most recent call last):
  File "c:\Users\username\Desktop\observables\my_script.py", line 167, in <module>
    main()
  File "c:\Users\username\Desktop\observables\my_script.py", line 104, in main
    time_stamp_num, agents_num, spatial_dimensions_num = dataframe_splitter()
  File "c:\Users\username\Desktop\observables\utilities.py", line 1351, in dataframe_splitter
    nonnan_arr = arr[~np.isnan(arr).any(axis=-1)].reshape(
ValueError: cannot reshape array of size 226512 into shape (863,newaxis,2)

最佳答案

如果您有一个 N 维数组,则需要沿 (N-1) 维减少掩码。

在您的例子中,您有 n = 3 维度,因此您有三种 ( comb(n, (n - 1)) ) 可能性。

例如,使用以下输入:

import numpy as np


arr = np.arange(3 * 4 * 5, dtype=np.float_).reshape((3, 4, 5))
print(arr[1, 1, 1])
# 26
arr[1, 1, 1] = np.nan
print(arr)
# [[[ 0.  1.  2.  3.  4.]
#   [ 5.  6.  7.  8.  9.]
#   [10. 11. 12. 13. 14.]
#   [15. 16. 17. 18. 19.]]

#  [[20. 21. 22. 23. 24.]
#   [25. nan 27. 28. 29.]
#   [30. 31. 32. 33. 34.]
#   [35. 36. 37. 38. 39.]]

#  [[40. 41. 42. 43. 44.]
#   [45. 46. 47. 48. 49.]
#   [50. 51. 52. 53. 54.]
#   [55. 56. 57. 58. 59.]]]

您可以减少(1, 2):

mask1 = np.isnan(arr).any(axis=(1, 2))
print(mask1)
# [False  True False]

print(arr[~mask1, :, :].shape)
# (2, 4, 5)

print(arr[~mask1, :, :])
# [[[ 0.  1.  2.  3.  4.]
#   [ 5.  6.  7.  8.  9.]
#   [10. 11. 12. 13. 14.]
#   [15. 16. 17. 18. 19.]]

#  [[40. 41. 42. 43. 44.]
#   [45. 46. 47. 48. 49.]
#   [50. 51. 52. 53. 54.]
#   [55. 56. 57. 58. 59.]]]

或在(0, 2)上:

mask2 = np.isnan(arr).any(axis=(0, 2))
print(mask2)
# [False  True False False]
print(arr[:, ~mask2, :].shape)
# (3, 3, 5)

print(arr[:, ~mask2, :])
# [[[ 0.  1.  2.  3.  4.]
#   [10. 11. 12. 13. 14.]
#   [15. 16. 17. 18. 19.]]

#  [[20. 21. 22. 23. 24.]
#   [30. 31. 32. 33. 34.]
#   [35. 36. 37. 38. 39.]]

#  [[40. 41. 42. 43. 44.]
#   [50. 51. 52. 53. 54.]
#   [55. 56. 57. 58. 59.]]]

或在(0, 1)上:

mask3 = np.isnan(arr).any(axis=(0, 1))
print(mask3)
# [False  True False False False]
print(arr[:, :, ~mask3].shape)
# (3, 4, 4)

print(arr[:, :, ~mask3])
# [[[ 0.  2.  3.  4.]
#   [ 5.  7.  8.  9.]
#   [10. 12. 13. 14.]
#   [15. 17. 18. 19.]]

#  [[20. 22. 23. 24.]
#   [25. 27. 28. 29.]
#   [30. 32. 33. 34.]
#   [35. 37. 38. 39.]]

#  [[40. 42. 43. 44.]
#   [45. 47. 48. 49.]
#   [50. 52. 53. 54.]
#   [55. 57. 58. 59.]]]

对于您的情况,如果您需要第三个维度保持不变,则不能减少 (0, 1),但可以减少 (1, 2) 中的任何一个和 (0, 2) 可以。您需要选择最适合您的。

关于python-3.x - 如何删除 3d 数组上包含 nan 元素的子数组并保留形状?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73584548/

相关文章:

python - 将列中的值附加到包含列表的另一列的开头

python - 为什么 .sum() 比 .any() 或 .max() 快?

numpy - Matplotlib 流线图,具有不中断或结束的流线

c - 如何在 C 中声明一个非常大的 3 维数组?

python - 如何使用另一个数据帧的 pandas 查询结果来过滤 pandas 数据帧

python - 如何知道字符串中是否有数字

python - Numpy 数组变量避免错误计算

Java : Why can't I declare an array as a simple Object?

javascript - 如何将展平多维数组(任意深度)的索引转换为原始索引?

python - KeyboardInterrupt 在不同的执行中表现不同