python - 如何在 tensorflow 的回调中访问训练和测试数据?

标签 python tensorflow callback

>     import tensorflow as tf
>     
>     class MyMetric(tf.keras.callbacks.Callback):
>        def on_epoch_end(self,epoch,logs={}):
>            # how to access X_train and X_val here
> 
>     ...
>     model.fit(X_train,y_train,batch_size=32,epochs=10,validation_data=(X_val,y_val),shuffle=True,callbacks=[MyMetric()]

我正在尝试使用回调在tensorflow 2.0中实现自定义指标。在 on_epoch_end 方法中,我需要访问提供给 fit 方法的训练和验证数据(整个样本,而不是批处理)。有什么办法可以做到这一点吗?谢谢!

最佳答案

接受训练和测试数据集作为自定义回调类的初始化参数,然后在 on_epoch_end 方法中使用它。

类似这样的事情

class MyMetric(keras.callbacks.Callback):

  def __init__(self, X_test):
    self.X_test = X_test

在调用 fit 时,将测试集作为参数传递给您的自定义回调,如下所示

model.fit(X_train,y_train,batch_size=32,epochs=10,validation_data=(X_val,y_val),shuffle=True,callbacks=[MyMetric(X_test)]

有关 https://keras.io/guides/writing_your_own_callbacks/ 的更多详细信息

关于python - 如何在 tensorflow 的回调中访问训练和测试数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59166386/

相关文章:

javascript - jQuery 多个 animate() 回调

performance - TF 包装器 : Performance of Keras vs Tensorpack

java - Android Studio 中使用 OpenCV 和 Tensorflow 进行实时情绪检测

python - Tensorflow 对象检测 - 避免重叠框

swift - placesClient - 正确使用 sharedClient 方法?

Iframe 中定义的 Javascript 回调函数

python - 如何创建一个空数组并追加它?

python - 如何优雅地链接if else

python - 使用 matplotlib 围绕点旋转矩形

python - 当 CWD 发生变化时,如何在 Python 模块中使用相对路径?