我最近发现了一种概念验证实现,它使用 numpy.zeros
在 one-hot 编码中准备功能:
data = np.zeros((len(raw_data), n_input, vocab_size),dtype=np.uint8)
如上所示,单个的类型为 np.uint8
。
检查模型后,我意识到tensorflow模型的输入占位符定义为tf.float32
:
x = tf.placeholder(tf.float32, [None, n_input, vocab_size], name="onehotin")
我的具体问题:
tensorflow 如何处理这种输入类型的“不匹配”。这些值 (0/1)
是否由 tensorflow 正确解释或转换。如果是这样,文档中是否提到了这一点。谷歌搜索后我找不到答案。应该提到的是,该模型的运行和值似乎是合理的。但是,将输入 numpy 特征输入为 np.float32
会导致需要大量内存。
相关性: 正在运行但经过错误训练的模型在采用输入管道/将模型投入生产后会有不同的行为。
最佳答案
Tensorflow 支持这样的数据类型转换。
在x + 1
等运算中,值1
会经过tf.convert_to_tensor
负责验证和转换的函数。有时会在后台手动调用该函数,并且当设置 dtype 参数时,该值会自动转换为此类型。
当您将数组输入到这样的占位符中时:
session.run(..., feed_dict={x: data})
...数据通过 np.asarray 调用显式转换为正确类型的 numpy 数组。源代码见python/client/session.py
。请注意,当数据类型不同时,此方法可能会重新分配缓冲区,而这正是您的情况所发生的情况。因此,您的内存优化并不完全按照您的预期工作:临时 32 位数据
是在内部分配的。
关于python - 将numpy uint8输入tensorflow float32占位符,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49309679/