python - 使用 Functional API 和 tf.GradientTape() 的组合在 Tensorflow 2.0 中训练时如何将模型图记录到张量板?

标签 python tensorflow tensorflow2.0

当我使用 Keras Functional API 或模型子调用 API 来创建模型并使用 tf.GradientTape() 来训练模型时,有人可以指导我如何将模型图记录到张量板上吗?

# Get the model.
inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)


optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy()


batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Iterate over epochs.
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):   

    with tf.GradientTape() as tape:
        logits = model(x_batch_train)
        loss_value = loss_fn(y_batch_train, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)


    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    if step % 200 == 0:
        print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
        print('Seen so far: %s samples' % ((step + 1) * 64))

Where should I insert the tensorboard logging for the model graph?

最佳答案

执行此操作的最佳方法(TF 2.3.0、TB 2.3.0)是使用 tf.summary 并使用 @tf.function 通过包装函数传递模型装饰器。

在您的例子中,将模型图导出到 TensorBoard 以供检查:

inputs = keras.Input(shape=(784,), name='digits')
x = layers.Dense(64, activation='relu', name='dense_1')(inputs)
x = layers.Dense(64, activation='relu', name='dense_2')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)
model = keras.Model(inputs=inputs, outputs=outputs)


# Solution Code:
writer = tf.summary.create_file_writer('./logs/')

@tf.function
def init_model(data, model):
    model(data)

tf.summary.trace_on(graph=True)
init_model(tf.zeros((1,784), model)

with writer.as_default():
    tf.summary.trace_export('name', step=0)

查看 TensorBoard 中的日志文件,您应该会看到这样的模型图。 enter image description here

关于python - 使用 Functional API 和 tf.GradientTape() 的组合在 Tensorflow 2.0 中训练时如何将模型图记录到张量板?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59041663/

相关文章:

python - 是否可以使用 python 安装 msi?

tensorflow - 当不再需要时如何从内存中释放张量?

tensorflow - 如何计算用从 .pb 文件加载的图形定义的 tensorflow 模型中可训练参数的总数?

python - Tensorflow - 内核似乎已经死亡。它将自动重新启动

python - 具有非常大的 HDF5 文件的 Tensorflow-IO 数据集输入管道

python - 动态创建方法

python - while 循环不会循环或执行

python - 预测方法为创建的模型给出错误

tensorflow2.0 - Tensorflow 2.0 : AttributeError: Tensor. 启用急切执行时名称无意义

python - Ruby optparse 限制