我正在尝试通过实现并行合并排序来学习Rust的多线程。一个简单的递归版本就可以了,但是这个版本:
use rand;
use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
//let mut input_line = String::new();
// println!("Input amount of numbers to sort:");
// let amount = match std::io::stdin().read_line(&mut input_line){
// Ok(_) => i64::from_str_radix(&input_line.trim(), 10).unwrap(),
// Err(_) => panic!("Error while reading amount of values")
// };
let amount = 1_000_000;
// let mut rnd = rand::thread_rng();
let mut arr: Vec<i64> = Vec::new();
for _ in 0..amount {
arr.push(rand::random::<i64>())
}
// println!("Vector before sort:");
// for elem in &arr {
// println!("{}", elem);
// }
merge_sort(&mut arr);
// println!("Vector after sort:");
// for elem in &arr {
// println!("{}", elem);
// }
}
fn merge_sort(arr: &mut Vec<i64>) {
let arr_len = arr.len();
let arr_slice = arr.as_mut_slice();
// simple_merge_sort(arr, 0 as usize, arr_len - 1 as usize);
let arc = Arc::new(Mutex::new(arr));
par_merge_sort(&mut arc, 0 as usize, arr_len - 1 as usize, 4);
}
fn simple_merge_sort(arr: &mut Vec<i64>, lo: usize, hi: usize) {
if lo == hi {
return;
}
let mi = (hi + lo) / 2;
simple_merge_sort(arr, lo, mi);
simple_merge_sort(arr, mi + 1, hi);
merge(arr, lo, mi, hi);
}
fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
if lo == hi {
return;
}
let mi = (hi + lo) / 2_usize;
if threads == 1 {
let mut simple_arr = arc.lock().unwrap();
simple_merge_sort(&mut simple_arr, lo, hi);
} else {
let thread_arc = Arc::from(*arc);
let thread_rest = threads / 2;
let thread_rest_2 = threads - thread_rest;
let thread1 = thread::spawn(move || {
par_merge_sort(&mut thread_arc, lo, mi, thread_rest);
});
let thread_arc = Arc::from(*arc);
let thread2 = thread::spawn(move || {
par_merge_sort(&mut thread_arc, mi + 1, hi, thread_rest_2);
});
thread1.join().unwrap();
thread2.join().unwrap();
}
let mutex = arc.lock().unwrap();
merge(&mut *mutex, lo, mi, hi);
}
fn merge(arr: &mut Vec<i64>, lo: usize, mi: usize, hi: usize) {
let mut lo_arr: Vec<i64> = Vec::new();
for i in lo..(mi + 1) {
let elem = *arr.get(i).unwrap();
lo_arr.push(elem);
}
let mut hi_arr: Vec<i64> = Vec::new();
for i in (mi + 1)..(hi + 1) {
let elem = *arr.get(i).unwrap();
hi_arr.push(elem);
}
let mut i = 0;
let mut j = 0;
let mut counter = lo;
while i < lo_arr.len() && j < hi_arr.len() {
let elem_i = *lo_arr.get(i).unwrap();
let elem_j = *hi_arr.get(j).unwrap();
if elem_i <= elem_j {
arr[counter] = elem_i;
i += 1;
} else {
// elem_j <= elem_i
arr[counter] = elem_j;
j += 1;
}
counter += 1;
}
if j == hi_arr.len() {
while i < lo_arr.len() {
let elem_i = *lo_arr.get(i).unwrap();
arr[counter] = elem_i;
i += 1;
counter += 1;
}
} else {
// i == lo_arr.len()
while j < hi_arr.len() {
let elem_j = *hi_arr.get(j).unwrap();
arr[counter] = elem_j;
j += 1;
counter += 1;
}
}
}
产生错误:error[E0621]: explicit lifetime required in the type of `arc`
--> src/main.rs:69:23
|
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
| ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
69 | let thread1 = thread::spawn(move || {
| ^^^^^^^^^^^^^ lifetime `'static` required
error[E0621]: explicit lifetime required in the type of `arc`
--> src/main.rs:73:23
|
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
| ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
73 | let thread2 = thread::spawn(move || {
| ^^^^^^^^^^^^^ lifetime `'static` required
最佳答案
由于您的问题是关于并行化而不是排序,因此我在下面的示例中省略了serial_sort
和merge
函数的实现,但是您可以使用已经拥有的代码轻松地将它们填充到自己中:
#![feature(is_sorted)]
use crossbeam; // 0.8.0
use rand; // 0.7.3
use rand::Rng;
fn random_vec(capacity: usize) -> Vec<i64> {
let mut vec = vec![0; capacity];
rand::thread_rng().fill(&mut vec[..]);
vec
}
fn parallel_sort(data: &mut [i64], threads: usize) {
let chunks = std::cmp::min(data.len(), threads);
let _ = crossbeam::scope(|scope| {
for slice in data.chunks_mut(data.len() / chunks) {
scope.spawn(move |_| serial_sort(slice));
}
});
merge(data, chunks);
}
fn serial_sort(data: &mut [i64]) {
// actual implementation omitted for conciseness
data.sort()
}
fn merge(data: &mut [i64], _sorted_chunks: usize) {
// actual implementation omitted for conciseness
data.sort()
}
fn main() {
let mut vec = random_vec(10_000);
parallel_sort(&mut vec, 4);
assert!(vec.is_sorted());
}
playgroundparallel_sort
将数据分解为n
块,并在其自己的线程中对每个块进行排序,而merge
将排序后的块放在一起,最后返回。
关于multithreading - 在Vec上实现并行/多线程合并排序,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65415293/