go - 如何并行化递归函数

标签 go

我正在尝试在 Go 中并行化一个递归问题,但我不确定最好的方法是什么。

我有一个递归函数,它是这样工作的:

func recFunc(input string) (result []string) {
    for subInput := range getSubInputs(input) {
        subOutput := recFunc(subInput)
        result = result.append(result, subOutput...)
    }
    result = result.append(result, getOutput(input)...)
}

func main() {
    output := recFunc("some_input")
    ...
}

因此函数调用自身 N 次(其中 N 在某个级别为 0),生成自己的输出并返回列表中的所有内容。

现在我想让这个函数并行运行。但我不确定最干净的方法是什么。我的想法:

  • 有一个“结果” channel ,所有函数调用都将结果发送到该 channel 。
  • 在主函数中收集结果。
  • 有一个 WaitGroup ,它决定何时收集所有结果。

问题:我需要等待 WaitGroup 并并行收集所有结果。我可以为此启动一个单独的 go 函数,但是我该如何退出这个单独的 go 函数?

func recFunc(input string) (result []string, outputChannel chan []string, waitGroup &sync.WaitGroup) {
    defer waitGroup.Done()
    waitGroup.Add(len(getSubInputs(input))
    for subInput := range getSubInputs(input) {
        go recFunc(subInput)
    }
    outputChannel <-getOutput(input)
}

func main() {
    outputChannel := make(chan []string)
    waitGroup := sync.WaitGroup{}

    waitGroup.Add(1)
    go recFunc("some_input", outputChannel, &waitGroup)

    result := []string{}
    go func() {
       nextResult := <- outputChannel
       result = append(result, nextResult ...)
    }
    waitGroup.Wait()
}

也许有更好的方法来做到这一点?或者我如何确保收集结果的匿名 go 函数在完成后退出?

最佳答案

tl;博士;

  • 递归算法应该对昂贵的资源(网络连接、goroutines、堆栈空间等)有限制
  • 应支持取消 - 以确保在不再需要某个结果时可以快速清理昂贵的操作
  • 分支遍历应该支持错误报告;这允许错误在堆栈中冒泡并返回部分结果,而不会导致整个递归遍历失败。

对于异步结果 - 无论是否使用递归 - 建议使用 channel 。此外,对于具有许多 goroutine 的长时间运行的作业,提供一种取消方法 ( context.Context ) 以帮助清理。

由于递归会导致资源的指数级消耗,因此设置限制很重要(请参阅 bounded parallelism)。

下面是我在异步任务中经常使用的设计模式:

  • 始终支持context.Context取消
  • 任务所需的 worker 数量
  • 返回一个chan结果和一个chan错误(将只返回一个错误或nil)

var (
    workers = 10
    ctx     = context.TODO() // use request context here - otherwise context.Background()
    input   = "abc"
)

resultC, errC := recJob(ctx, workers, input) // returns results & `error` channels

// asynchronous results - so read that channel first in the event of partial results ...
for r := range resultC {
    fmt.Println(r)
}

// ... then check for any errors
if err := <-errC; err != nil {
    log.Fatal(err)
}

递归:

由于递归可以快速横向扩展,因此需要一种一致的方式来用工作填充有限的工作人员列表,同时还要确保工作人员有空时,他们会迅速从其他(过度工作的)工作人员那里接手工作。

与其创建一个管理层,不如使用一个协作的 worker 对等系统:

  • 每个工作人员共享一个输入 channel
  • 在递归输入 (subIinputs) 之前检查是否有其他 worker 空闲
    • 如果是,委托(delegate)给那个 worker
    • 如果不是,则当前工作人员继续递归该分支

使用此算法,有限数量的工作人员很快就会因工作而饱和。任何提前完成其分支机构的工作人员 - 将很快被另一名工作人员委派一个子分支机构。最终所有 worker 都将用完子分支,此时所有 worker 都将空闲(阻塞)并且递归任务可以完成。

要实现这一目标,需要进行一些仔细的协调。允许工作人员写入到输入 channel 有助于通过委托(delegate)进行这种对等协调。 “递归深度”WaitGroup 用于跟踪所有工作人员的所有分支何时耗尽。

(为了包括上下文支持和错误链接 - 我更新了您的 getSubInputs 函数以获取 ctx 并返回可选的 error):

func recFunc(ctx context.Context, input string, in chan string, out chan<- string, rwg *sync.WaitGroup) error {

    defer rwg.Done() // decrement recursion count when a depth of recursion has completed

    subInputs, err := getSubInputs(ctx, input)
    if err != nil {
        return err
    }

    for subInput := range subInputs { 
        rwg.Add(1) // about to recurse (or delegate recursion)

        select {
        case in <- subInput:
            // delegated - to another goroutine

        case <-ctx.Done():
            // context canceled...

            // but first we need to undo the earlier `rwg.Add(1)`
            // as this work item was never delegated or handled by this worker
            rwg.Done()
            return ctx.Err()

        default:
            // noone available to delegate - so this worker will need to recurse this item themselves
            err = recFunc(ctx, subInput, in, out, rwg)
            if err != nil {
                return err
            }
        }

        select {
        case <-ctx.Done():
            // always check context when doing anything potentially blocking (in this case writing to `out`)
            // context canceled
            return ctx.Err()

        case out <- subInput:
        }
    }

    return nil
}

连接件:

recJob 创建:

  • 输入和输出 channel - 由所有 worker 共享
  • “递归” WaitGroup 检测所有工作人员何时空闲
    • 然后可以安全地关闭“输出” channel
  • 所有 worker 的错误 channel
  • 通过将初始输入写入输入 channel 来启动递归工作负载

func recJob(ctx context.Context, workers int, input string) (resultsC <-chan string, errC <-chan error) {

    // RW channels
    out := make(chan string)
    eC := make(chan error, 1)

    // R-only channels returned to caller
    resultsC, errC = out, eC

    // create workers + waitgroup logic
    go func() {

        var err error // error that will be returned to call via error channel

        defer func() {
            close(out)
            eC <- err
            close(eC)
        }()

        var wg sync.WaitGroup
        wg.Add(1)
        in := make(chan string) // input channel: shared by all workers (to read from and also to write to when they need to delegate)

        workerErrC := createWorkers(ctx, workers, in, out, &wg)

        // get the ball rolling, pass input job to one of the workers
        // Note: must be done *after* workers are created - otherwise deadlock
        in <- input

        errCount := 0

        // wait for all worker error codes to return
        for err2 := range workerErrC {
            if err2 != nil {
                log.Println("worker error:", err2)
                errCount++
            }
        }

        // all workers have completed
        if errCount > 0 {
            err = fmt.Errorf("PARTIAL RESULT: %d of %d workers encountered errors", errCount, workers)
            return
        }

        log.Printf("All %d workers have FINISHED\n", workers)
    }()

    return
}

最后,创建 worker :

func createWorkers(ctx context.Context, workers int, in chan string, out chan<- string, rwg *sync.WaitGroup) (errC <-chan error) {

    eC := make(chan error) // RW-version
    errC = eC              // RO-version (returned to caller)

    // track the completeness of the workers - so we know when to wrap up
    var wg sync.WaitGroup
    wg.Add(workers)

    for i := 0; i < workers; i++ {
        i := i
        go func() {
            defer wg.Done()

            var err error

            // ensure the current worker's return code gets returned
            // via the common workers' error-channel
            defer func() {
                if err != nil {
                    log.Printf("worker #%3d ERRORED: %s\n", i+1, err)
                } else {
                    log.Printf("worker #%3d FINISHED.\n", i+1)
                }
                eC <- err
            }()

            log.Printf("worker #%3d STARTED successfully\n", i+1)

            // worker scans for input
            for input := range in {

                err = recFunc(ctx, input, in, out, rwg)
                if err != nil {
                    log.Printf("worker #%3d recurseManagers ERROR: %s\n", i+1, err)
                    return
                }
            }

        }()
    }

    go func() {
        rwg.Wait() // wait for all recursion to finish
        close(in)  // safe to close input channel as all workers are blocked (i.e. no new inputs)
        wg.Wait()  // now wait for all workers to return
        close(eC)  // finally, signal to caller we're truly done by closing workers' error-channel
    }()

    return
}

关于go - 如何并行化递归函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65548405/

相关文章:

logging - 使用 systemd 在 golang 中旋转日志文件

go - 未定义 : function (declared in another package)

go - 从保险库 KV 值构建动态字符串

java - ECDSA 签名 Java vs Go

image - Go Code 在 go test 和 go run 中的行为不同

parsing - 在 Go 中并发解析二进制文件中的记录

mysql - 去和 mysql 和 cloneTLSConfig

docker - 通信运行 golang 的多个容器

multithreading - Go 中的简单工作池

go - Go lang GORM中的 “sql:” -“”代表什么?