是否可以使用 Tensorflow 数据类型 tf.dtypes.DType例如Python类型提示中的tf.int32?
from typing import (
Union,
)
import tensorflow as tf
import numpy as np
def f(
a: Union[tf.int32, tf.float32] # <----
):
return a * 2
def g(a: Union[np.int32, np.float32]):
return a * 2
def test_a():
f(tf.cast(1.0, dtype=tf.float32)) # <----
g(np.float32(1.0)) # Numpy type has no issue
它会导致下面的错误,并想知道这是否可能。python3.8/typing.py:149: in _type_check
raise TypeError(f"{msg} Got {arg!r:.100}.")
E TypeError: Union[arg, ...]: each arg must be a type. Got tf.int32.
最佳答案
我假设您希望您的功能接受:
tf.float32
np.float32
float
tf.int32
np.int32
int
并且总是返回,比如说,
tf.float32
.不完全确定这是否涵盖您的用例,但我会为您的输入参数放置一个广泛的类型,并在您的函数中转换为所需的类型。experimental_follow_type_hints
可以与类型注释一起使用,通过减少昂贵的图形回溯次数来提高性能。例如,即使输入是非 Tensor 值,用 tf.Tensor 注释的参数也会转换为 Tensor。from typing import TYPE_CHECKING
import tensorflow as tf
import numpy as np
@tf.function(experimental_follow_type_hints=True)
def foo(x: tf.Tensor) -> tf.float32:
if x.dtype == tf.int32:
x = tf.dtypes.cast(x, tf.float32)
return x * 2
a = tf.cast(1.0, dtype=tf.float32)
b = tf.cast(1.0, dtype=tf.int32)
c = np.float32(1.0)
d = np.int32(1.0)
e = 1.0
f = 1
for var in [a, b, c, d, e, f]:
print(f"input: {var},\tinput type: {type(var)},\toutput: {foo(var)}\toutput type: {type(foo(var))}")
if TYPE_CHECKING:
reveal_locals()
python3 stack66968102.py
的输出:input: 1.0, input type: <class 'tensorflow.python.framework.ops.EagerTensor'>, output: 2.0 output dtype: <dtype: 'float32'>
input: 1, input type: <class 'tensorflow.python.framework.ops.EagerTensor'>, output: 2.0 output dtype: <dtype: 'float32'>
input: 1.0, input type: <class 'numpy.float32'>, output: 2.0 output dtype: <dtype: 'float32'>
input: 1, input type: <class 'numpy.int32'>, output: 2.0 output dtype: <dtype: 'float32'>
input: 1.0, input type: <class 'float'>, output: 2.0 output dtype: <dtype: 'float32'>
input: 1, input type: <class 'int'>, output: 2.0 output dtype: <dtype: 'float32'>
mypy stack66968102.py --ignore-missing-imports
的输出:stack66968102.py:27: note: Revealed local types are:
stack66968102.py:27: note: a: Any
stack66968102.py:27: note: b: Any
stack66968102.py:27: note: c: numpy.floating[numpy.typing._32Bit*]
stack66968102.py:27: note: d: numpy.signedinteger[numpy.typing._32Bit*]
stack66968102.py:27: note: e: builtins.float
stack66968102.py:27: note: f: builtins.int
stack66968102.py:27: note: tf: Any
stack66968102.py:27: note: var: Any
关于python 类型提示 - 可以使用 tensorflow 数据类型吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66968102/