python - 将numpy uint8输入tensorflow float32占位符

标签 python numpy tensorflow type-conversion one-hot-encoding

我最近发现了一种概念验证实现,它使用 numpy.zeros 在 one-hot 编码中准备功能:

data = np.zeros((len(raw_data), n_input, vocab_size),dtype=np.uint8)

如上所示,单个的类型为 np.uint8。 检查模型后,我意识到tensorflow模型的输入占位符定义为tf.float32:

x = tf.placeholder(tf.float32, [None, n_input, vocab_size], name="onehotin")

我的具体问题: tensorflow 如何处理这种输入类型的“不匹配”。这些值 (0/1) 是否由 tensorflow 正确解释或转换。如果是这样,文档中是否提到了这一点。谷歌搜索后我找不到答案。应该提到的是,该模型的运行和值似乎是合理的。但是,将输入 numpy 特征输入为 np.float32 会导致需要大量内存。

相关性: 正在运行但经过错误训练的模型在采用输入管道/将模型投入生产后会有不同的行为。

最佳答案

Tensorflow 支持这样的数据类型转换。

x + 1等运算中,值1会经过tf.convert_to_tensor负责验证和转换的函数。有时会在后台手动调用该函数,并且当设置 dtype 参数时,该值会自动转换为此类型。

当您将数组输入到这样的占位符中时:

session.run(..., feed_dict={x: data})

...数据通过 np.asarray 调用显式转换为正确类型的 numpy 数组。源代码见python/client/session.py 。请注意,当数据类型不同时,此方法可能会重新分配缓冲区,而这正是您的情况所发生的情况。因此,您的内存优化并不完全按照您的预期工作:临时 32 位数据是在内部分配的。

关于python - 将numpy uint8输入tensorflow float32占位符,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49309679/

相关文章:

python - 如何构造嵌套的 numpy 记录数组?

python - 在花式索引时使用范围?

python - 如何为 TensorFlow 中的不同层或变量设置不同的学习率?

tensorflow - TensorFlow 中的生存分析

python - 如何使用 GCP Cloud SQL 作为数据流源和/或 Python 接收器?

python - 将 C++ double 返回给 Python?

python ,Scipy : Building triplets using large adjacency matrix

python - TensorFlow 入门中的 `my_input_fn` 如何允许枚举数据?

python - 如何使 pygal 将图例分布到 n 列或正确截断?

python - 数字是 float64 吗?