multithreading - 在Vec上实现并行/多线程合并排序

标签 multithreading rust

我正在尝试通过实现并行合并排序来学习Rust的多线程。一个简单的递归版本就可以了,但是这个版本:

use rand;

use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    //let mut input_line = String::new();
    // println!("Input amount of numbers to sort:");
    // let amount = match std::io::stdin().read_line(&mut input_line){
    //     Ok(_) => i64::from_str_radix(&input_line.trim(), 10).unwrap(),
    //     Err(_) => panic!("Error while reading amount of values")
    // };
    let amount = 1_000_000;

    // let mut rnd = rand::thread_rng();
    let mut arr: Vec<i64> = Vec::new();
    for _ in 0..amount {
        arr.push(rand::random::<i64>())
    }

    // println!("Vector before sort:");
    // for elem in &arr {
    //     println!("{}", elem);
    // }

    merge_sort(&mut arr);

    // println!("Vector after sort:");
    // for elem in &arr {
    //     println!("{}", elem);
    // }
}

fn merge_sort(arr: &mut Vec<i64>) {
    let arr_len = arr.len();
    let arr_slice = arr.as_mut_slice();

    // simple_merge_sort(arr, 0 as usize, arr_len - 1 as usize);

    let arc = Arc::new(Mutex::new(arr));
    par_merge_sort(&mut arc, 0 as usize, arr_len - 1 as usize, 4);
}

fn simple_merge_sort(arr: &mut Vec<i64>, lo: usize, hi: usize) {
    if lo == hi {
        return;
    }

    let mi = (hi + lo) / 2;
    simple_merge_sort(arr, lo, mi);
    simple_merge_sort(arr, mi + 1, hi);

    merge(arr, lo, mi, hi);
}

fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
    if lo == hi {
        return;
    }

    let mi = (hi + lo) / 2_usize;
    if threads == 1 {
        let mut simple_arr = arc.lock().unwrap();
        simple_merge_sort(&mut simple_arr, lo, hi);
    } else {
        let thread_arc = Arc::from(*arc);
        let thread_rest = threads / 2;
        let thread_rest_2 = threads - thread_rest;
        let thread1 = thread::spawn(move || {
            par_merge_sort(&mut thread_arc, lo, mi, thread_rest);
        });
        let thread_arc = Arc::from(*arc);
        let thread2 = thread::spawn(move || {
            par_merge_sort(&mut thread_arc, mi + 1, hi, thread_rest_2);
        });

        thread1.join().unwrap();
        thread2.join().unwrap();
    }

    let mutex = arc.lock().unwrap();
    merge(&mut *mutex, lo, mi, hi);
}

fn merge(arr: &mut Vec<i64>, lo: usize, mi: usize, hi: usize) {
    let mut lo_arr: Vec<i64> = Vec::new();
    for i in lo..(mi + 1) {
        let elem = *arr.get(i).unwrap();
        lo_arr.push(elem);
    }

    let mut hi_arr: Vec<i64> = Vec::new();
    for i in (mi + 1)..(hi + 1) {
        let elem = *arr.get(i).unwrap();
        hi_arr.push(elem);
    }

    let mut i = 0;
    let mut j = 0;
    let mut counter = lo;

    while i < lo_arr.len() && j < hi_arr.len() {
        let elem_i = *lo_arr.get(i).unwrap();
        let elem_j = *hi_arr.get(j).unwrap();

        if elem_i <= elem_j {
            arr[counter] = elem_i;
            i += 1;
        } else {
            // elem_j <= elem_i
            arr[counter] = elem_j;
            j += 1;
        }
        counter += 1;
    }

    if j == hi_arr.len() {
        while i < lo_arr.len() {
            let elem_i = *lo_arr.get(i).unwrap();
            arr[counter] = elem_i;
            i += 1;
            counter += 1;
        }
    } else {
        // i == lo_arr.len()
        while j < hi_arr.len() {
            let elem_j = *hi_arr.get(j).unwrap();
            arr[counter] = elem_j;
            j += 1;
            counter += 1;
        }
    }
}
产生错误:
error[E0621]: explicit lifetime required in the type of `arc`
  --> src/main.rs:69:23
   |
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
   |                        ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
69 |         let thread1 = thread::spawn(move || {
   |                       ^^^^^^^^^^^^^ lifetime `'static` required

error[E0621]: explicit lifetime required in the type of `arc`
  --> src/main.rs:73:23
   |
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
   |                        ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
73 |         let thread2 = thread::spawn(move || {
   |                       ^^^^^^^^^^^^^ lifetime `'static` required

最佳答案

由于您的问题是关于并行化而不是排序,因此我在下面的示例中省略了serial_sortmerge函数的实现,但是您可以使用已经拥有的代码轻松地将它们填充到自己中:

#![feature(is_sorted)]

use crossbeam; // 0.8.0
use rand; // 0.7.3
use rand::Rng;

fn random_vec(capacity: usize) -> Vec<i64> {
    let mut vec = vec![0; capacity];
    rand::thread_rng().fill(&mut vec[..]);
    vec
}

fn parallel_sort(data: &mut [i64], threads: usize) {
    let chunks = std::cmp::min(data.len(), threads);
    let _ = crossbeam::scope(|scope| {
        for slice in data.chunks_mut(data.len() / chunks) {
            scope.spawn(move |_| serial_sort(slice));
        }
    });
    merge(data, chunks);
}

fn serial_sort(data: &mut [i64]) {
    // actual implementation omitted for conciseness
    data.sort()
}

fn merge(data: &mut [i64], _sorted_chunks: usize) {
    // actual implementation omitted for conciseness
    data.sort()
}

fn main() {
    let mut vec = random_vec(10_000);
    parallel_sort(&mut vec, 4);
    assert!(vec.is_sorted());
}
playgroundparallel_sort将数据分解为n块,并在其自己的线程中对每个块进行排序,而merge将排序后的块放在一起,最后返回。

关于multithreading - 在Vec上实现并行/多线程合并排序,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65415293/

相关文章:

rust - 如何为函数实现特征

Java 多线程服务器 *有时* 在 ServerSocket.accept() 方法中抛出 SocketException(套接字关闭)

python - 在多线程生产者-消费者模式下,如何让工作线程在工作完成后退出?

java - 即使在同步方法中访问 int,int 也会递增两次

rust - 如何阅读基于Tokio的Hyper请求的整个正文?

rust - f32 没有实现减法?

rust - 多个 rust 文件需要使用相同的结构和函数

rust - 如何根据 Rocket 服务器的状态响应不同的值?

java - 同步线程等待多个线程

java - 迭代列表并将 Callables 提交到 ExecutorService