concurrency - 在分布式 tensorflow 中制作屏障的正确方法是什么?

标签 concurrency tensorflow

在分布式训练期间,我想在每个时期后同步,对首席工作人员进行一些计算,并根据这些计算继续或停止训练。我需要一个障碍才能做到这一点。

我在文档中没有看到任何类似的内容,因此我实现了基于队列的解决方案(类似于分布式训练中梯度的存储和应用方式):

def build_barrier(tasks, task_index, barrier_name):
    queues = []
    for i, task in enumerate(tasks):
        with tf.device('%s/cpu:0' % task):
            with tf.name_scope(barrier_name):
                queues.append(
                    tf.FIFOQueue(
                        len(tasks),
                        (tf.float32),
                        shapes=(()),
                        name=str(i),
                        shared_name=str(i)))

    with tf.control_dependencies([queue.enqueue(1.) for queue in queues]):
        return queues[task_index].dequeue_many(len(tasks))

这个想法是为每个工作人员创建一个队列。对于“信号”,我在每个队列中推送一个 token ,对于“加入”,我从相应的队列中取出如此多的 token ,我想同步多少个任务。

问题是:这是正确的方法还是有更好的方法?

最佳答案

您的解决方案与 SyncReplicasOptimizer 非常相似。在SyncReplicasOptimizer中,它使用同步 token 队列来模拟屏障,并为每个变量使用累加器来累积并平均梯度更新。这是一种非常典型的批量同步并行,同时它还有在 Tensorflow 中实现陈旧同步并行的额外工作。

此外,Tensorflow还提供了Barrier最新版本,您可以查看更多信息。

关于concurrency - 在分布式 tensorflow 中制作屏障的正确方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39638468/

相关文章:

go - 同时从两个 channel 消费会导致 goroutine 占用我的 RAM

python-3.x - Unicode解码错误: 'utf-8' codec can't decode byte 0xea in position 23: invalid continuation byte

python - 模块 'tensorflow' 没有带有股票预测的属性 'get_default_graph'

Python 无法创建 NewWriteableFile(tensorflow.python.framework.errors_impl.NotFoundError : Failed to create a NewWriteableFile: )

python - Tensorflow:训练时如何将模型保存在内存中

python - TensorFlow 机器学习 - 对用户输入进行预测

git - 如果第二次推送只有第一次推送的快进,并发 git 推送是否总是安全的?

c++ - 基于 GPU 的 N^2 比较

go - 如何在并发环境中为每个关键参数运行一次 golang 函数调用?

java - 在游戏循环中等待 FutureTask 完成