python - 使用 tf.Dataset 训练的模型进行推理

标签 python tensorflow tensorflow-datasets

我已经使用 tf.data.Dataset API 训练了一个模型,所以我的训练代码看起来像这样

with graph.as_default():
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(scale_features, num_parallel_calls=n_workers)
    dataset = dataset.shuffle(10000)
    dataset = dataset.padded_batch(batch_size, padded_shapes={...})
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   train_dataset.output_types,
                                                   train_dataset.output_shapes)
    batch = iterator.get_next()
    ... 
    # Model code
    ...
    iterator = dataset.make_initializable_iterator()

with tf.Session(graph=graph) as sess:
    train_handle = sess.run(iterator.string_handle())
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        sess.run(train_iterator.initializer)
        while True:
            try:
                sess.run(optimizer, feed_dict={handle: train_handle})
            except tf.errors.OutOfRangeError:
               break

模型训练完成后,我想推断数据集中没有的示例,但我不确定如何去做。

明确一点,我知道如何使用另一个数据集,例如,我只是在测试时将句柄传递给我的测试集。

问题是关于给定缩放方案和网络需要句柄的事实,如果我想对未写入 TFRecord 的新示例进行预测,我将如何去做?

如果我要修改 batch,我会事先负责缩放,这是我想尽可能避免的事情。

那么我应该如何从模型中推断单个示例以 tf.data.Dataset 方式进行训练? (这不是出于生产目的,而是为了评估如果我更改特定功能会发生什么)

最佳答案

实际上图中有一个名为“IteratorGetNext:0”的张量名称 当你使用dataset api时,你可以使用下面的方式直接设置 输入:

#get a tensor from a graph 
input tensor : input = graph.get_tensor_by_name("IteratorGetNext:0")
# difine the target tensor you want evaluate for your prediction
prediction tensor: predictions=...
# finally call session to run 
then sess.run(predictions, feed_dict={input: np.asanyarray(images), ...})

关于python - 使用 tf.Dataset 训练的模型进行推理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50940667/

相关文章:

python - 具有阈值的累积销售数据形成具有 bool 值的新系列/列?

python - ValueError : Layer expects 2 input(s), 但它在训练 CNN 时收到 1 个输入张量

python - 如何使用 Keras ImageDataGenerator 预测单个图像?

python - 重置在 Tensorflow 2 数据集中到底意味着什么?

tensorflow - 如何从 Pandas DataFrame 到用于 NLP 的 Tensorflow BatchDataset?

python - 有没有可能如何在 Scribus 之外的 python 中使用 scribus 模块?

python bin 数据并返回 bin 中点(可能使用 pandas.cut 和 qcut)

python - 在 Python 中保护 HTTP 请求

Tensorflow 提示未检测到支持 CUDA 的设备

tensorflow - 有什么方法可以将 tensorflow lite (.tflite) 文件转换回 keras 文件 (.h5)?