我有一个 f32
类型的二维数组(来自 ndarray::ArrayView2
),我想找到每一行中最大值的索引,并将索引值到另一个数组。
Python 中的等价物是这样的:
import numpy as np
for i in range (0, max_val, batch_size):
sims = xp.dot(batch, vectors.T)
# sims is the dot product of batch and vectors.T
# the shape is, for example, (1024, 10000)
best_rows[i: i+batch_size] = sims.argmax(axis = 1)
在 Python 中,函数 .argmax
非常快,但我在 Rust 中没有看到类似的函数。最快的方法是什么?
最佳答案
考虑一般 Ord
类型的简单情况:答案会略有不同,具体取决于您是否知道值是 Copy
,但这是代码:
fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}
fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}
基本思想是我们将数组中的每个项目(实际上,一个切片——无论它是 Vec 还是数组或更奇特的东西都没有关系)与其索引配对,使用 std::iter::Iterator
函数只根据值(不是索引)找到最大值,然后只返回索引。如果切片为空 None
将被返回。根据文档,将返回最右边的索引;如果您需要最左边的,请执行 rev()
after enumerate()
。
rev()
、enumerate()
、max_by_key()
和 max_by()
记录在 here ; slice::iter()
被记录为 here(但作为 rust 开发者,在没有文档的情况下,它需要出现在你要记忆的 list 上); map
是 Option::map()
记录的 here(同上)。哦,cmp
是 Ord::cmp
但大多数时候您可以使用不需要它的 Copy
版本(例如,如果你在比较整数)。
现在问题来了:f32
不是 Ord
因为 IEEE float 的工作方式。大多数语言都忽略了这一点并且有微妙的错误算法。在 Ord
上提供总订单的最受欢迎的箱子(通过声明所有 NaN 都相等,并且大于所有数字)似乎是 ordered-float 。假设它实现正确,它应该非常轻量级。它确实引入了 num_traits
,但这是最流行的数字库的一部分,因此很可能已经被其他依赖项引入。
在这种情况下,您可以通过将 ordered_float::OrderedFloat
(元组类型的“构造函数”)映射到切片迭代器(slice.iter().map( ordered_float::OrderedFloat)
).由于您只需要最大元素的位置,因此之后无需提取 f32。
关于arrays - 查找数组中最大值索引的最快方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57813951/