python - Tensorflow variables are not initialized using Between-graph replication

Tags: python python-3.x tensorflow distributed

I have the following Python code, test.py, which uses distributed TensorFlow's "Between-graph Replication":

import argparse
import logging

import tensorflow as tf

log = logging.getLogger(__name__)

# Job Names
PARAMETER_SERVER = "ps"
WORKER_SERVER = "worker"

# Cluster Details
CLUSTER_SPEC = {
    PARAMETER_SERVER: ["localhost:2222"],
    WORKER_SERVER: ["localhost:1111", "localhost:1112"]}


def parse_command_arguments():
    """ Set up and parse the command line arguments passed for experiment. """
    parser = argparse.ArgumentParser(
        description="Parameters and Arguments for the Test.")
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )

    return parser.parse_args()


def start_server(job_name, task_index):
    """ Create a server based on a cluster spec. """
    cluster = tf.train.ClusterSpec(CLUSTER_SPEC)
    server = tf.train.Server(
        cluster, job_name=job_name, task_index=task_index)

    return server, cluster


def model():
    """ Build up a simple estimator model. """
    # Build a linear model and predict values
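    # dtype is passed by keyword: the second positional parameter of
    # tf.Variable is `trainable`, not the dtype.
    W = tf.Variable([.3], dtype=tf.float32)
    b = tf.Variable([-.3], dtype=tf.float32)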
    x = tf.placeholder(tf.float32)
    linear_model = W * x + b
    y = tf.placeholder(tf.float32)
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # Loss sub-graph
    loss = tf.reduce_sum(tf.square(linear_model - y))

    # optimizer
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train = optimizer.minimize(loss, global_step=global_step)

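    # This only builds the op; the returned init_op is not used below.
    # MonitoredTrainingSession creates its own initializer and runs it on the chief.
    init_op = tf.global_variables_initializer()
    log.info("Variables initialized ...")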

    return W, b, loss, x, y, train, global_step, init_op


if __name__ == "__main__":
    # Initializing logging with level "INFO".
    logging.basicConfig(level=logging.INFO)

    # Parse arguments from command line.
    arguments = parse_command_arguments()
    job_name = arguments.job_name
    task_index = arguments.task_index

    # Start a server.
    server, cluster = start_server(job_name, task_index)

    if job_name == "ps":
        server.join()
    else:
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            W, b, loss, x, y, train, global_step, init_op = model()
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(arguments.task_index == 0 and (
                            arguments.job_name == 'worker'))) as sess:
            step = 0
            # training data
            x_train = [1, 2, 3, 4]
            y_train = [0, -1, -2, -3]
            while not sess.should_stop() and step < 1000:
                _, step = sess.run(
                    [train, global_step], {x: x_train, y: y_train})

            # evaluate training accuracy
            curr_W, curr_b, curr_loss = sess.run(
                [W, b, loss], {x: x_train, y: y_train})
            print("W: %s b: %s loss: %s" % (curr_W, curr_b, curr_loss))

I run the code as 3 separate processes on a single machine (a CPU-only MacPro), in the following order:

  1. Parameter server: $ python test.py --task_index 0 --job_name ps
  2. Worker 1: $ python test.py --task_index 0 --job_name worker
  3. Worker 2: $ python test.py --task_index 1 --job_name worker
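The same start-up order can be scripted. The sketch below is a hypothetical helper (not part of TensorFlow) that assumes test.py sits in the current directory; it builds one command per task from the cluster layout, parameter servers first and then workers in index order, and starts each as its own process.

```python
import subprocess
import sys

# Same cluster layout as CLUSTER_SPEC in test.py.
CLUSTER_SPEC = {
    "ps": ["localhost:2222"],
    "worker": ["localhost:1111", "localhost:1112"],
}


def build_commands(cluster_spec, script="test.py"):
    """Return one argv list per task: all PS tasks first, then workers by index."""
    commands = []
    for job in ("ps", "worker"):
        for index in range(len(cluster_spec.get(job, []))):
            commands.append([sys.executable, script,
                             "--task_index", str(index),
                             "--job_name", job])
    return commands


def launch(cluster_spec, script="test.py"):
    """Start every task as a separate process (hypothetical helper)."""
    return [subprocess.Popen(cmd) for cmd in build_commands(cluster_spec, script)]
```

`procs = launch(CLUSTER_SPEC)` starts all three processes. The PS task blocks in `server.join()` and never exits on its own, so wait on the workers (`procs[1].wait()`, `procs[2].wait()`) and then call `procs[0].terminate()`.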

And I found that the "Worker 2" process ran into an error:

$ python test.py --task_index 1 --job_name worker
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:197] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:197] Initialize GrpcChannelCache for job worker -> {0 -> localhost:1111, 1 -> localhost:1112}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:211] Started server with target: grpc://localhost:1112
INFO:__main__:Variables initialized ...
I tensorflow/core/distributed_runtime/master_session.cc:993] Start master session 9912c75f2921fe13 with config: 

INFO:tensorflow:Waiting for model to be ready.  Ready_for_local_init_op:  None, ready: Variables not initialized: Variable, Variable_1, global_step
INFO:tensorflow:Waiting for model to be ready.  Ready_for_local_init_op:  None, ready: Variables not initialized: Variable, Variable_1, global_step

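The "Worker 2" process then just hangs there. The log shows that the TensorFlow variables for "Worker 2" were not initialized, so I wonder whether MonitoredTrainingSession has a bug in coordinating variable initialization across TensorFlow sessions (or elsewhere), or whether I am missing something in my code.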

Note: the code was run with TensorFlow 0.12.

Best answer

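I think this is the "expected behavior" of the coordination protocol used by tf.train.MonitoredTrainingSession. In a recent answer, I explained how that protocol is geared toward long-running training jobs, which is why a worker sleeps for 30 seconds between checks of whether the variables have been initialized.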

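There is a race condition between Worker 1, which runs the initialization op, and Worker 2, which checks the variables. If Worker 2 "wins" the race, it observes that some variables are uninitialized and goes into a 30-second sleep before checking again.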

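However, your program does very little computation overall, so within those 30 seconds Worker 1 can finish its work and terminate. When Worker 2 next checks whether the variables have been initialized, it creates a new tf.Session that tries to connect to the other tasks, but Worker 1 is no longer running, so you will see a log message like this one (repeated every 10 seconds or so):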

I tensorflow/core/distributed_runtime/master.cc:193] CreateSession still waiting for response from worker: /job:worker/replica:0/task:0

This will not be a problem when the training job runs for substantially longer than 30 seconds.
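The race and its consequence can be reproduced in a toy model. The sketch below is plain Python with no TensorFlow, and every class and function in it is hypothetical. The 30-second recheck interval is scaled down to 0.5 s; `needs_chief=True` stands for the default case, where a new session must reach every task, and `needs_chief=False` stands for a session whose device filters exclude the other worker. (In TensorFlow itself this recheck interval appears to correspond to the `recovery_wait_secs` argument of `tf.train.SessionManager`, which defaults to 30.)

```python
import threading
import time

# Names taken from the log in the question; the model itself is hypothetical.
VARIABLES = ("Variable", "Variable_1", "global_step")
STILL_WAITING = ("CreateSession still waiting for response from worker: "
                 "/job:worker/replica:0/task:0")


class ToyCluster:
    """Shared state: has the chief run the init op, and is it still running?"""

    def __init__(self):
        self.lock = threading.Lock()
        self.initialized = False
        self.chief_alive = True


def chief(cluster, init_delay, job_secs):
    time.sleep(init_delay)            # chief is still starting when worker 2 first checks
    with cluster.lock:
        cluster.initialized = True    # chief runs the init op
    time.sleep(job_secs)              # a very short training job
    with cluster.lock:
        cluster.chief_alive = False   # chief process exits


def non_chief(cluster, recheck_secs, needs_chief, max_checks=3):
    """Poll for readiness, sleeping a fixed interval between checks."""
    log = []
    for _ in range(max_checks):
        with cluster.lock:
            initialized, chief_alive = cluster.initialized, cluster.chief_alive
        if needs_chief and not chief_alive:
            log.append(STILL_WAITING)  # the new session cannot reach worker 0
        elif initialized:
            log.append("ready")
            return log
        else:
            log.append("Variables not initialized: " + ", ".join(VARIABLES))
        time.sleep(recheck_secs)
    return log


def simulate(needs_chief, init_delay=0.1, job_secs=0.1, recheck_secs=0.5):
    cluster = ToyCluster()
    thread = threading.Thread(target=chief, args=(cluster, init_delay, job_secs))
    thread.start()
    log = non_chief(cluster, recheck_secs, needs_chief)
    thread.join()
    return log
```

`simulate(needs_chief=True)` returns one "Variables not initialized" entry followed only by "CreateSession still waiting" entries: the chief ran the init op 0.1 s in but had already exited at 0.2 s, before the first recheck at 0.5 s. `simulate(needs_chief=False)` reports "ready" on the first recheck.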

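One workaround is to remove the interdependency between the workers by setting a "device filter". Since the individual workers do not communicate in a typical between-graph configuration, you can tell TensorFlow to ignore the absence of the other worker at session creation time, using a tf.ConfigProto: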

# Each worker only needs to contact the PS task(s) and the local worker task.
config = tf.ConfigProto(device_filters=[
    '/job:ps', '/job:worker/task:%d' % arguments.task_index])

with tf.train.MonitoredTrainingSession(
        master=server.target,
        config=config,
        is_chief=(arguments.task_index == 0 and (
                  arguments.job_name == 'worker'))) as sess:
    # ...
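Conceptually, a device filter narrows the set of tasks a session has to reach. The sketch below models that in plain Python; the function names are hypothetical, and the field-by-field matching of filters as partial device names (so `/job:worker/task:1` does not match `task:10`) is an assumption based on how TensorFlow documents filters, not a copy of its implementation.

```python
# Same cluster layout as CLUSTER_SPEC in the question's test.py.
CLUSTER_SPEC = {
    "ps": ["localhost:2222"],
    "worker": ["localhost:1111", "localhost:1112"],
}


def parse_device(spec):
    """Split '/job:worker/replica:0/task:1' into {'job': 'worker', 'replica': '0', 'task': '1'}."""
    fields = {}
    for part in spec.strip("/").split("/"):
        key, _, value = part.partition(":")
        fields[key] = value
    return fields


def matches(filter_spec, device_spec):
    """A partially specified filter matches when every field it names agrees."""
    device = parse_device(device_spec)
    return all(device.get(key) == value
               for key, value in parse_device(filter_spec).items())


def tasks_to_contact(cluster_spec, device_filters=()):
    """Tasks a new session must reach: every task, unless filters narrow the set."""
    tasks = ["/job:%s/replica:0/task:%d" % (job, index)
             for job, addresses in sorted(cluster_spec.items())
             for index in range(len(addresses))]
    if not device_filters:
        return tasks
    return [task for task in tasks
            if any(matches(f, task) for f in device_filters)]


def worker_filters(task_index):
    """Mirror of the ConfigProto above: every PS task plus this worker only."""
    return ["/job:ps", "/job:worker/task:%d" % task_index]
```

Without filters, `tasks_to_contact(CLUSTER_SPEC)` lists all three tasks, so Worker 2 depends on Worker 1 staying up; with `worker_filters(1)` only the PS task and Worker 2 itself remain.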

Regarding "python - Tensorflow variables are not initialized using Between-graph replication", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/43084960/
