python - 使用 tensorflow-gpu 1.14 和 tf.distribute.MirroredStrategy() 的自定义训练循环导致 ValueError

标签 python tensorflow keras distributed

我正在尝试使用 tf.distribute.MirroredStrategy() 在多个 GPU 上运行自定义训练循环。虽然训练循环在单个 GPU 上可以完美运行,但当我尝试使用多个 GPU 时,会抛出 ValueError: 'handle' is not available outside the replica context or a 'tf.distribute.Strategy.update()' call。我使用的是 tensorflow 1.14 和 Python 3.7.3。

我在下面给出了一个最小示例。自定义训练循环在单个 GPU 上运行没有问题,但当我使用 tf.distribute.MirroredStrategy() 在多个 GPU 上运行时会失败,并显示如下错误消息(完整输出):

ValueError                                Traceback (most recent call last)
<ipython-input-11-3fda5d330457> in <module>
      1 with mirrored_strategy.scope():
----> 2     model, train_op, X1_in, X2_in = create_model_and_train_op()
      3     with tf.Session() as sess:
      4         sess.run(tf.global_variables_initializer())
      5         for sample_ind in range(n_samples):

<ipython-input-7-8f5b3971bbe2> in create_model_and_train_op()
      6 
      7     model = Model(name='BNN',inputs=[X1_in,X2_in], outputs=[loss])
----> 8     train_op = tf.train.AdamOptimizer().minimize(loss)
      9 
     10     return model, train_op, X1_in, X2_in

~/.local/share/virtualenvs/keras_bnn_lv-NP1oBJBi/lib/python3.7/site-packages/tensorflow/python/training/optimizer.py in minimize(self, loss, global_step, var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, name, grad_loss)
    401         aggregation_method=aggregation_method,
    402         colocate_gradients_with_ops=colocate_gradients_with_ops,
--> 403         grad_loss=grad_loss)
    404 
    405     vars_with_grad = [v for g, v in grads_and_vars if g is not None]

~/.local/share/virtualenvs/keras_bnn_lv-NP1oBJBi/lib/python3.7/site-packages/tensorflow/python/training/optimizer.py in compute_gradients(self, loss, var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, grad_loss)
    510         gate_gradients=(gate_gradients == Optimizer.GATE_OP),
    511         aggregation_method=aggregation_method,
--> 512         colocate_gradients_with_ops=colocate_gradients_with_ops)
    513     if gate_gradients == Optimizer.GATE_GRAPH:
    514       grads = control_flow_ops.tuple(grads)

~/.local/share/virtualenvs/keras_bnn_lv-NP1oBJBi/lib/python3.7/site-packages/tensorflow/python/ops/gradients_impl.py in gradients(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients)
    156         ys, xs, grad_ys, name, colocate_gradients_with_ops,
    157         gate_gradients, aggregation_method, stop_gradients,
--> 158         unconnected_gradients)
    159   # pylint: enable=protected-access
    160 

~/.local/share/virtualenvs/keras_bnn_lv-NP1oBJBi/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    595     xs = [
    596         x.handle if resource_variable_ops.is_resource_variable(x) else x
--> 597         for x in xs
    598     ]
    599     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(

~/.local/share/virtualenvs/keras_bnn_lv-NP1oBJBi/lib/python3.7/site-packages/tensorflow/python/ops/gradients_util.py in <listcomp>(.0)
    595     xs = [
    596         x.handle if resource_variable_ops.is_resource_variable(x) else x
--> 597         for x in xs
    598     ]
    599     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(

~/.local/share/virtualenvs/keras_bnn_lv-NP1oBJBi/lib/python3.7/site-packages/tensorflow/python/distribute/values.py in handle(self)
    641       device = distribute_lib.get_update_device()
    642       if device is None:
--> 643         raise ValueError("`handle` is not available outside the replica context"
    644                          " or a `tf.distribute.Strategy.update()` call.")
    645     return self.get(device=device).handle

ValueError: `handle` is not available outside the replica context or a `tf.distribute.Strategy.update()` call.

我通过 Google 搜索到的唯一修复建议是升级到 tensorflow 2.0.0 测试版(beta)。我想知道是否有办法在 1.14 中解决这个问题。

这是我尝试过的最小示例:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Concatenate
from tensorflow.keras.models import Model

import sys
# Report the interpreter and TensorFlow versions used for this reproduction.
print(sys.version)
print(tf.__version__)

# Width of each of the two model inputs, and the number of training rows.
input_dim = 42
n_samples = 10000

# Two independent arrays of uniform random values in [0, 1), one row per
# sample; x1_data is drawn first, then x2_data.
x1_data = np.random.rand(n_samples, input_dim)
x2_data = np.random.rand(n_samples, input_dim)

def create_model_and_train_op():
    """Build the two-input Keras model and an Adam training op on its output.

    Returns:
        A tuple ``(model, train_op, X1_in, X2_in)``: the Keras ``Model``
        named ``'BNN'``, the op returned by ``AdamOptimizer().minimize`` and
        the two ``Input`` tensors that have to be fed when running the op.
    """
    # Two inputs of width ``input_dim`` each, joined along the last axis.
    X1_in = Input(shape=(input_dim,))
    X2_in = Input(shape=(input_dim,))
    merged = Concatenate(axis=-1)([X1_in, X2_in])

    # The output of a single Dense unit is used directly as the quantity
    # the optimizer minimises.
    loss = Dense(1)(merged)
    model = Model(name='BNN', inputs=[X1_in, X2_in], outputs=[loss])

    # The minimize() call below is where the ValueError is raised when this
    # function runs under MirroredStrategy().
    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss)

    return model, train_op, X1_in, X2_in


##### Single GPU: Runs without problems
model, train_op, X1_in, X2_in = create_model_and_train_op()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One optimisation step per sample; each row is fed as a batch of one.
    for sample_ind in range(n_samples):
        feed = {
            X1_in: x1_data[sample_ind].reshape(1, input_dim),
            X2_in: x2_data[sample_ind].reshape(1, input_dim),
        }
        sess.run(train_op, feed_dict=feed)


##### Multiple GPU: Results in error message
mirrored_strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(mirrored_strategy.num_replicas_in_sync))

with mirrored_strategy.scope():
    # Building the train op inside the strategy scope is what raises the
    # ValueError shown in the traceback; the loop below is never reached.
    model, train_op, X1_in, X2_in = create_model_and_train_op()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Same per-sample loop as the single-GPU run.
        for sample_ind in range(n_samples):
            feed = {
                X1_in: x1_data[sample_ind].reshape(1, input_dim),
                X2_in: x2_data[sample_ind].reshape(1, input_dim),
            }
            sess.run(train_op, feed_dict=feed)

最佳答案

我通过仅加载权重(weights)的方式解决了这个问题。
下面是根据 Keras 手册中关于多 GPU 训练的示例修改而来的代码。

import tensorflow as tf
from tensorflow import keras
import os 
from tensorflow.python.keras.backend import set_session

def _checkpoint_sort_key(name):
    """Return a sort key that orders checkpoint file names by epoch number.

    The epoch is read from the text between the last ``-`` and the first
    ``.`` of *name* (``"ckpt-10.cpkt.index"`` -> ``10``).  Names without a
    numeric epoch get ``-1`` so they sort below every numbered checkpoint;
    ties are broken by the name itself, which keeps plain lexicographic
    order among files of the same epoch.
    """
    stem = name.split(".", 1)[0]
    try:
        epoch = int(stem.rpartition("-")[2])
    except ValueError:
        epoch = -1
    return (epoch, name)


def _latest_checkpoint_prefix(directory):
    """Return the weight-file prefix of the newest checkpoint in *directory*.

    Returns ``None`` when the directory has no entries.  The prefix is the
    directory, a ``/`` and the newest file name with its last extension
    removed (``"ckpt-10.cpkt.index"`` -> ``"<directory>/ckpt-10.cpkt"``).
    """
    names = os.listdir(directory)
    if not names:
        return None
    # Compare epochs numerically: a plain string sort would rank "ckpt-9"
    # above "ckpt-10" and restore stale weights after ten or more epochs.
    newest = max(names, key=_checkpoint_sort_key)
    return directory + "/" + os.path.splitext(newest)[0]


def get_compiled_model():
    """Create the session, build and compile the model, restoring weights.

    A new ``tf.Session`` and the current default graph are stored in the
    module globals ``sess`` and ``graph`` so that later code can re-enter
    them.  If the module-level ``checkpoint_dir`` contains any files, the
    weights of the newest checkpoint are loaded before compiling.

    Returns:
        The compiled ``keras.Model``.
    """
    # Make a simple 2-layer densely-connected neural network.

    global sess
    global graph
    sess = tf.Session()
    graph = tf.get_default_graph()
    # a special trick from here https://github.com/tensorflow/tensorflow/issues/28287

    # IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for keras!
    # Otherwise, their weights will be unavailable in the threads after the session there has been set
    set_session(sess)
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)

    latest_checkpoint = _latest_checkpoint_prefix(checkpoint_dir)
    if latest_checkpoint is not None:
        print("Restoring from", latest_checkpoint)
        model.load_weights(latest_checkpoint)

    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    """Load MNIST and return batched train, validation and test datasets.

    Images are flattened to 784 values and scaled by 1/255 as float32;
    labels are cast to float32.  The last ``num_val_samples`` training
    examples are split off as the validation set.

    Returns:
        A tuple ``(train, validation, test)`` of ``tf.data.Dataset``
        objects, each batched with ``batch_size``.
    """
    batch_size = 32
    num_val_samples = 10000

    def _flatten_and_scale(images):
        # Flatten each image to 784 values and scale by 1/255.
        return images.reshape(-1, 784).astype("float32") / 255

    def _as_batches(features, labels):
        # Pair features with labels and group them into fixed-size batches.
        pairs = tf.data.Dataset.from_tensor_slices((features, labels))
        return pairs.batch(batch_size)

    # The raw data comes back as Numpy arrays.
    (train_images, train_labels), (test_images, test_labels) = (
        keras.datasets.mnist.load_data()
    )

    train_images = _flatten_and_scale(train_images)
    test_images = _flatten_and_scale(test_images)
    train_labels = train_labels.astype("float32")
    test_labels = test_labels.astype("float32")

    # Everything from this (negative) index onwards is held out for validation.
    split = -num_val_samples
    return (
        _as_batches(train_images[:split], train_labels[:split]),
        _as_batches(train_images[split:], train_labels[split:]),
        _as_batches(test_images, test_labels),
    )



# Load the batched train / validation / test datasets once at module level;
# run_training() and the evaluation step below read these globals.
train_dataset, val_dataset, test_dataset = get_dataset()


# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
# exist_ok=True creates the directory only when it is missing, without the
# window between a separate existence check and the creation call.
os.makedirs(checkpoint_dir, exist_ok=True)


def make_or_restore_model():
    """Return a compiled model by delegating to ``get_compiled_model``.

    The message below is printed unconditionally; whether weights are
    restored from a checkpoint is decided inside ``get_compiled_model``.
    """
    print("Creating a new model")
    model = get_compiled_model()
    return model


def run_training(epochs=1):
    """Train the model under a ``MirroredStrategy`` and return it.

    Args:
        epochs: Number of epochs passed to ``model.fit``.

    Returns:
        The trained ``keras.Model``.
    """
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))

    # The model is created (and possibly restored) inside the strategy scope.
    with strategy.scope():
        model = make_or_restore_model()

    # Saves the weights only (save_weights_only=True) at the end of every
    # epoch; the epoch number is part of the file name.
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_dir + "/ckpt-{epoch}.cpkt",
        save_weights_only=True,
        save_freq="epoch",
    )

    # Re-enter the graph and session that get_compiled_model() stored in the
    # module globals before fitting.
    with graph.as_default():
        set_session(sess)
        model.fit(
            train_dataset,
            epochs=epochs,
            validation_data=val_dataset,
            callbacks=[checkpoint_callback],
            verbose=2,
        )
    return model


# First run: builds the model (restoring weights only if ./ckpt already
# holds files) and trains it for two epochs, writing a checkpoint per epoch.
model = run_training(epochs=2)

# Evaluate the model returned by the first run on the test dataset.
print("Evaluating")
model.evaluate(test_dataset)

# Second run: get_compiled_model() now finds the checkpoint files written
# above and loads those weights before training one more epoch.
run_training(epochs=1)

关于python - 使用 tensorflow-gpu 1.14 和 tf.distribute.MirroredStrategy() 的自定义训练循环导致 ValueError,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56836895/

相关文章:

machine-learning - 如何正确输入形状或输入尺寸?

python - 迭代大型数据框的有效方法

tensorflow - Keras/Tensorflow 解决线性回归任务的局限性

json - TensorflowJS : Failed to parse model. json

java - 是否可以编写Java客户端来访问Tensorflow Server?

python - 使用 keras 的 sk-learn API 时出错

python - Keras:如何加载具有两个输出和自定义损失函数的模型?

python - 不和谐 py - 错误 : Unable to extract JS player URL

python - Matplotlib将某些字符绘制为空白正方形

python - Python 中的分位数函数