我正在尝试使用ndarray计算点积,并且遇到了我不理解的编译错误。
我的基本功能是
use ndarray::{ArrayD, ArrayBase};
pub fn cosine<L>(v1: &ArrayBase<f64, L>, v2: &ArrayBase<f64, L>) -> f64 {
let x: f64 = v1.dot(&v2) / (v1.dot(v1) * v2.dot(v2)).sqrt();
return x
}
pub fn cosine2(v1: &ArrayD<f64>, v2: &ArrayD<f64>) -> f64 {
let x: f64 = v1.dot(v2) / (v1.dot(v1) * v2.dot(v2)).sqrt();
return x
}
无法编译:error[E0277]: the trait bound `f64: ndarray::data_traits::RawData` is not satisfiedchgraph
--> src/simple.rs:3:1
|
3 | / pub fn cosine<L>(v1: &ArrayBase<f64, L>, v2: &ArrayBase<f64, L>) -> f64 {
4 | | let x: f64 = v1.dot(&v2) / (v1.dot(v1) * v2.dot(v2)).sqrt();
5 | | }
| |_^ the trait `ndarray::data_traits::RawData` is not implemented for `f64`
|
= note: required by `ndarray::ArrayBase`
如果我注释掉cosine
,我从cosine2
收到一个错误:error[E0599]: no method named `dot` found for reference `&ndarray::ArrayBase<ndarray::data_repr::OwnedRepr<f64>, ndarray::dimension::dim::Dim<ndarray::dimension::dynindeximpl::IxDynImpl>>` in the current scope
--> src/simple.rs:9:21
|
9 | let x: f64 = v1.dot(v2) / (v1.dot(v1) * v2.dot(v2)).sqrt();
| ^^^ method not found in `&ndarray::ArrayBase<ndarray::data_repr::OwnedRepr<f64>, ndarray::dimension::dim::Dim<ndarray::dimension::dynindeximpl::IxDynImpl>>`
(以及其他点产品的另外两个副本)。为什么第二个版本找不到该方法?看来ArrayD
是基于Array
的类型,而后者又是基于ArrayBase
的类型,因此ArrayD::dot
应该是现有方法。我只需要能够传递
ArrayD
,因此我对任何一种都能使用的版本感到满意。我的Cargo.toml的相关部分是
[dependencies.ndarray]
version = "0.13.1"
[features]
default = ["ndarray/blas"]
最佳答案
首先,ArrayBase
的数据类型不是由数据类型索引的,而是由数据类型的RawData
包装器索引的。第二,dot
需要实现Dot
特性。因此,您应该将这两个都添加到特征范围:
use ndarray::linalg::Dot;
use ndarray::{ArrayBase, ArrayD, RawData};
pub fn cosine<D, L>(v1: &ArrayBase<D, L>, v2: &ArrayBase<D, L>) -> f64
where
D: RawData<Elem = f64>,
ArrayBase<D, L>: Dot<ArrayBase<D, L>, Output = f64>,
{
let x: f64 = v1.dot(&v2) / (v1.dot(v1) * v2.dot(v2)).sqrt();
return x;
}
关于rust - 未为f64实现RawData,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64774226/