arrays - 查找数组中最大值索引的最快方法是什么?

标签 arrays rust

我有一个 f32 类型的二维数组(来自 ndarray::ArrayView2),我想找到每一行中最大值的索引,并将索引值到另一个数组。

Python 中的等价物是这样的:

import numpy as np

for i in range (0, max_val, batch_size):
   sims = xp.dot(batch, vectors.T) 
   # sims is the dot product of batch and vectors.T
   # the shape is, for example, (1024, 10000)

   best_rows[i: i+batch_size] = sims.argmax(axis = 1)

在 Python 中,函数 .argmax 非常快,但我在 Rust 中没有看到类似的函数。最快的方法是什么?

最佳答案

考虑一般 Ord 类型的简单情况:答案会略有不同,具体取决于您是否知道值是 Copy,但这是代码:

fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}

fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}

基本思想是我们将数组中的每个项目(实际上,一个切片——无论它是 Vec 还是数组或更奇特的东西都没有关系)与其索引配对,使用 std::iter::Iterator 函数只根据值(不是索引)找到最大值,然后只返回索引。如果切片为空 None 将被返回。根据文档,将返回最右边的索引;如果您需要最左边的,请执行 rev() after enumerate()

rev()enumerate()max_by_key()max_by() 记录在 here ; slice::iter() 被记录为 here(但作为 rust 开发者,在没有文档的情况下,它需要出现在你要记忆的 list 上); mapOption::map() 记录的 here(同上)。哦,cmpOrd::cmp 但大多数时候您可以使用不需要它的 Copy 版本(例如,如果你在比较整数)。


现在问题来了:f32 不是 Ord 因为 IEEE float 的工作方式。大多数语言都忽略了这一点并且有微妙的错误算法。在 Ord 上提供总订单的最受欢迎的箱子(通过声明所有 NaN 都相等,并且大于所有数字)似乎是 ordered-float 。假设它实现正确,它应该非常轻量级。它确实引入了 num_traits,但这是最流行的数字库的一部分,因此很可能已经被其他依赖项引入。

在这种情况下,您可以通过将 ordered_float::OrderedFloat(元组类型的“构造函数”)映射到切片迭代器(slice.iter().map( ordered_float::OrderedFloat)).由于您只需要最大元素的位置,因此之后无需提取 f32。

关于arrays - 查找数组中最大值索引的最快方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57813951/

相关文章:

rust - 如何创建一个大小合适的闭包或在结构上实现 Fn/FnMut/FnOnce?

rust - 如何在 Rust 中使用 serde 为容器设置 "deserialize with"

rust - 返回迭代器(或任何其他特征)的正确方法是什么?

c - 在函数中传递二维数组 - 从 C 中的函数返回二维数组

C 中的计数排序 - 错误 : Use of undeclared identifier

使用带有 char 字符串和数组的结构的 C 程序

rust - 将 1 元组结构转换为包含的元素的惯用方法是什么?

javascript - 为什么 NodeJS 在计算素数和方面比 Rust 快?

arrays - 对数组进行排序并检索排序后的索引

PHP - 连接/级联多维数组键