python - 无效参数错误 : cannot compute Sub as input #1(zero-based) was expected to be a uint8 tensor but is a float tensor [Op:Sub]

标签 python tensorflow keras

问题

请帮忙了解错误原因及解决方法。

代码

import tensorflow as tf
import numpy as np

fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_full = np.concatenate((x_train, x_test), axis=0)

layer = tf.keras.layers.experimental.preprocessing.Normalization()
layer.adapt(x_full)
layer(x_train)

错误

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-16-699c47b6db55> in <module>
----> 1 ds = layer(x_train)

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    966           with base_layer_utils.autocast_context_manager(
    967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
    969           self._handle_activity_regularization(inputs, outputs)
    970           self._set_mask_metadata(inputs, outputs, input_masks)

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/keras/layers/preprocessing/normalization.py in call(self, inputs)
    109     mean = array_ops.reshape(self.mean, self._broadcast_shape)
    110     variance = array_ops.reshape(self.variance, self._broadcast_shape)
--> 111     return (inputs - mean) / math_ops.sqrt(variance)
    112 
    113   def compute_output_shape(self, input_shape):

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
    982     with ops.name_scope(None, op_name, [x, y]) as name:
    983       if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 984         return func(x, y, name=name)
    985       elif not isinstance(y, sparse_tensor.SparseTensor):
    986         try:

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in sub(x, y, name)
  10098         pass  # Add nodes to the TensorFlow graph.
  10099     except _core._NotOkStatusException as e:
> 10100       _ops.raise_from_not_ok_status(e, name)
  10101   # Add nodes to the TensorFlow graph.
  10102   _, _, _op, _outputs = _op_def_library._apply_op_helper(

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6651   message = e.message + (" name: " + name if name is not None else "")
   6652   # pylint: disable=protected-access
-> 6653   six.raise_from(core._status_to_exception(e.code, message), None)
   6654   # pylint: enable=protected-access
   6655 

~/conda/envs/tensorflow/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a uint8 tensor but is a float tensor [Op:Sub]

尝试

尝试了 dtype arg 但同样的错误。

layer = tf.keras.layers.experimental.preprocessing.Normalization(dtype='float32')

除以 1.0 解决了问题,但不确定原始原因。

x_full = np.concatenate((x_train, x_test), axis=0) / 1.0
x_train = x_train / 1.0

Keras 只能使用 float32 吗?

相关问题

最佳答案

原因是 preprocessing.Normalization expect float32 但你的数据是 uint8 并因此出现错误。

这实际上是 Tensorflow 的问题,而不是 Keras 本身,因为这是更快的计算。

提醒:float 和 int 在处理器的不同位置计算,每个处理器在不同数据类型上有不同的性能,例如 nvidia 的 gpus 使用 float32float16 更快而 16 位的 arm cpu 更快。

Pytorch 也需要两个变量是相同的数据类型,否则它将无法工作。

在 python 中将一个整数与一个 float 相除会自动得到一个新的 float ,x_train = x_train/1.0 将使 x_train float32(或float64float16 取决于你在 ~/.keras/keras.json 中的内容,但你在这里有 float32 ).

关于python - 无效参数错误 : cannot compute Sub as input #1(zero-based) was expected to be a uint8 tensor but is a float tensor [Op:Sub],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64243290/

相关文章:

android - Tensorflow Lite : No module named tf. contrib.lite

Keras Tensorflow 后端未检测到 GPU

python - 如何修复 'Expected to see 2 array(s), but instead got the following list of 1 arrays'

python - 使用错误的输入类型调用函数时打印 "Wrong Type"- Python

Python Postgres - psycopg2.ProgrammingError : no results to fetch

python - 有没有比我做的更好的方法来猜测可能的未知变量而不用蛮力?机器学习?

java - 使用 TensorFlow for Java 的内存泄漏

python - groupby 多个值列

python - 如何在 python 中使用列表理解等于多个变量?

tensorflow - 将卡住模型 '.pb' 文件转换为 '.tflite' 文件所需的参数 input_arrays 和 output_arrays 是什么?