python - 使用 tf.data.Datasets 卡住 Tensorflow 图时确定输入节点

标签 python tensorflow

我使用 Tensorflow tf.data.Dataset API 作为我的输入管道,如下所示:

train_dataset = tf.data.Dataset.from_tensor_slices((trn_X,trn_y))
train_dataset = 
train_dataset.map(_trn_parse_function,num_parallel_calls=12)
train_dataset = 
train_dataset.shuffle(buffer_size=1000).repeat(args.num_epochs)# 
.batch(args.batch_size)
train_dataset = train_dataset.apply(tf.contrib.data.batch_and_drop_remainder(args.batch_size))
train_dataset = train_dataset.prefetch(buffer_size=600)



val_dataset = tf.data.Dataset.from_tensor_slices((val_X,val_y))
val_dataset = val_dataset.map(_val_parse_function,num_parallel_calls=4)
val_dataset = val_dataset.repeat(1)
val_dataset = val_dataset.apply(tf.contrib.data.batch_and_drop_remainder(args.batch_size))
val_dataset = val_dataset.prefetch(buffer_size=200)


handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, 
train_dataset.output_shapes)

images,labels = iterator.get_next()


train_iter = train_dataset.make_initializable_iterator()
val_iter = val_dataset.make_initializable_iterator()

然后使用此代码在训练和验证数据集之间切换:

# Define training and validation handlers
training_handle = sess.run(train_iter.string_handle())
validation_handle = sess.run(val_iter.string_handle())
sess.run(train_iter.initializer)
sess.run(val_iter.initializer)

...
loss = sess.run([train_op],feed_dict={handle:training_handle, 
is_training:True})

训练后,我保存权重,然后将图形从保存的检查点 ((.meta) 卡住为 .pb 格式。随后,运行 optimize_for_inference.py tensorflow repo 中提供的工具。此脚本需要定义 input_nodes_names。我无法确定哪个是图形的正确输入节点。以下是我的图形的节点:

['Variable/initial_value',
'Variable',
'Variable/Assign',
'Variable/read',
'increment_global_step/value',
'increment_global_step',
'Placeholder',
'is_training',
'tensors/component_0',
'tensors/component_1',
'num_parallel_calls',
'batch_size',
'count',
'buffer_size',
'OneShotIterator',
'IteratorToStringHandle',
'IteratorGetNext',
....
....
'output/Softmax]

可以很容易地确定输出节点,但不能确定输入节点。

最佳答案

handle = tf.placeholder(tf.string, shape=[]) 是您的输入,因此张量很可能是“Placeholder:0”。

不过这样写会更有意义:

handle = tf.placeholder(tf.string, shape=[], name="input_placeholder")

那你肯定知道。

关于python - 使用 tf.data.Datasets 卡住 Tensorflow 图时确定输入节点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50955127/

相关文章:

python - Tensorflow 模型评估基于批量大小

python - Sympysolvet() 在应该返回数值答案时返回 ConditionSet()

Python 相当于 Perl 的 'w' 打包格式

python - 如果 range() 是 Python 3.3 中的生成器,为什么我不能在范围上调用 next()?

Tensorflow如何检查模型

python - 训练后Keras合并单张图片的批处理

python - 如何在 ML pyspark 管道中添加我自己的函数作为自定义阶段?

python - 根据文本版本控制删除重复项

python - 动态编辑 Tensorflow 对象检测的管道配置

javascript - 在 tensorflow.js 中设置权重的函数初始值设定项