python - 重新排列基于python中列表的数组元素

标签 python numpy pytorch

我有一个大小为 2, 1403 的二维数组 a 和一个包含 2 个列表的列表 b

a.shape = (2, 1403) # a 是二维数组,每一行都有唯一的元素。

len(b) = 2 # b 是列表

len(b[0]), len(b[1]) = 415, 452 # 这里 b 中的列表也有唯一元素

b[0] 和 b[1] 中的所有元素分别出现在 a[0] 和 a[1]

现在我想根据 b 的元素重新排列 a 的元素。我想重新排列 b[0] 中也存在于 a[0] 中的所有元素应该出现在 a[0] 的结尾,意味着新的 a 应该是 a[0][:-len(b[0])] = b[0],类似地 a[1][:-len(b[1])] = b[1]

玩具示例

a[[1,2,3,4,5,6,7,8,9,10,11,12],[1,2, 3,4,5,6,7,8,9,10,11,12]

b 有像 [[5, 9, 10], [2, 6, 8, 9, 11]] 这样的元素

new_a 变为 [[1,2,3,4,6,7,8,11,12,5,9,10], [1,3,4,5 ,7,10,12,2,6,8,9,11]]

我写了一个循环遍历所有元素的代码,它变得非常慢,如下所示

a_temp = []
remove_temp = []
for i, array in enumerate(a):
    a_temp_inner = []
    remove_temp_inner = []
    for element in array:
        if element not in b[i]:
            a_temp_inner.append(element) # get all elements first which are not present in b
        else:
            remove_temp_inner.append(element) #if any element present in b, remove it from main array

    a_temp.append(a_temp_inner)
    remove_temp.append(b_temp_inner)

a_temp = torch.tensor(a_temp)
remove_temp = torch.tensor(remove_temp)
a = torch.cat((a_temp, remove_temp), dim = 1) 

任何人都可以帮助我实现一些比这更好的更快的实现

最佳答案

假设 a 是一个 np.arrayb 是一个 list 你可以使用

np.array([np.concatenate((i[~np.in1d(i, j)], j)) for i, j in zip(a,b)])

输出

array([[ 1,  2,  3,  4,  6,  7,  8, 11, 12,  5,  9, 10],
       [ 1,  3,  4,  5,  7, 10, 12,  2,  6,  8,  9, 11]])

如果b 包含空列表,则可以进行微优化

np.array([np.concatenate((i[~np.in1d(i, j)], j)) if j else i for i, j in zip(a,b)])

在我的基准测试中,对于少于 ~100 个元素的 np.arrays 转换 .tolist()np.concatenate 快/p>

np.array([i[~np.in1d(i, j)].tolist() + j for i, j in zip(a,b)])

此解决方案的数据示例和导入

import numpy as np

a = np.array([
        [1,2,3,4,5,6,7,8,9,10,11,12],
        [1,2,3,4,5,6,7,8,9,10,11,12]
    ])
b = [[5, 9, 10],
     [2, 6, 8, 9, 11]]

关于python - 重新排列基于python中列表的数组元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71542505/

相关文章:

python - Cassandra ResultSet 遍历一次后就变空了

numpy - 如何让 NumPy 在 Ubuntu 中使用 OpenBlas?

python - Huggingface Transformer - GPT2 从保存的检查点恢复训练

python - 调用 Detectron2LayoutModel 时出现 OSError

python - 我的 tensorflow keras 模型每次预测总是给出 1.0。它无法正常工作

python - 如何根据多个值的总和删除 Pandas 中的行?

python - 分析和取消wave文件的常用频率

python - 提取 csv 文件特定列以在 Python 中列出

python - "Transform"Numpy 数组 : Move Dimension

python - 如何将字符串列表转换为 pytorch 中的张量?