python 类型提示 - 可以使用 tensorflow 数据类型吗?

标签 python tensorflow type-hinting

是否可以使用 Tensorflow 数据类型 tf.dtypes.DType例如Python类型提示中的tf.int32?

from typing import (
    Union,
)
import tensorflow as tf
import numpy as np


def f(
    a: Union[tf.int32, tf.float32]  # <----
): 
    return a * 2


def g(a: Union[np.int32, np.float32]):
    return a * 2


def test_a():
    f(tf.cast(1.0, dtype=tf.float32))  # <----
    g(np.float32(1.0))                 # Numpy type has no issue

它会导致下面的错误,并想知道这是否可能。
python3.8/typing.py:149: in _type_check
    raise TypeError(f"{msg} Got {arg!r:.100}.")
E   TypeError: Union[arg, ...]: each arg must be a type. Got tf.int32.

最佳答案

我假设您希望您的功能接受:

  • tf.float32
  • np.float32
  • float
  • tf.int32
  • np.int32
  • int

  • 并且总是返回,比如说,tf.float32 .不完全确定这是否涵盖您的用例,但我会为您的输入参数放置一个广泛的类型,并在您的函数中转换为所需的类型。
    experimental_follow_type_hints 可以与类型注释一起使用,通过减少昂贵的图形回溯次数来提高性能。例如,即使输入是非 Tensor 值,用 tf.Tensor 注释的参数也会转换为 Tensor。
    from typing import TYPE_CHECKING
    import tensorflow as tf
    import numpy as np
    
    
    @tf.function(experimental_follow_type_hints=True)
    def foo(x: tf.Tensor) -> tf.float32:
        if x.dtype == tf.int32:
            x = tf.dtypes.cast(x, tf.float32)
        return x * 2
    
    a = tf.cast(1.0, dtype=tf.float32)
    b = tf.cast(1.0, dtype=tf.int32)
    
    c = np.float32(1.0)
    d = np.int32(1.0)
    
    e = 1.0
    f = 1
    
    for var in [a, b, c, d, e, f]:
        print(f"input: {var},\tinput type: {type(var)},\toutput: {foo(var)}\toutput type: {type(foo(var))}")
    
    if TYPE_CHECKING:
        reveal_locals()
    
    python3 stack66968102.py 的输出:
    input: 1.0,     input type: <class 'tensorflow.python.framework.ops.EagerTensor'>,      output: 2.0     output dtype: <dtype: 'float32'>
    input: 1,       input type: <class 'tensorflow.python.framework.ops.EagerTensor'>,      output: 2.0     output dtype: <dtype: 'float32'>
    input: 1.0,     input type: <class 'numpy.float32'>,    output: 2.0     output dtype: <dtype: 'float32'>
    input: 1,       input type: <class 'numpy.int32'>,      output: 2.0     output dtype: <dtype: 'float32'>
    input: 1.0,     input type: <class 'float'>,    output: 2.0     output dtype: <dtype: 'float32'>
    input: 1,       input type: <class 'int'>,      output: 2.0     output dtype: <dtype: 'float32'>
    
    mypy stack66968102.py --ignore-missing-imports 的输出:
    stack66968102.py:27: note: Revealed local types are:
    stack66968102.py:27: note:     a: Any
    stack66968102.py:27: note:     b: Any
    stack66968102.py:27: note:     c: numpy.floating[numpy.typing._32Bit*]
    stack66968102.py:27: note:     d: numpy.signedinteger[numpy.typing._32Bit*]
    stack66968102.py:27: note:     e: builtins.float
    stack66968102.py:27: note:     f: builtins.int
    stack66968102.py:27: note:     tf: Any
    stack66968102.py:27: note:     var: Any
    

    关于python 类型提示 - 可以使用 tensorflow 数据类型吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66968102/

    相关文章:

    python - 为什么 Conda 虚拟环境如此之大?

    python - 如何使用 pandas/python 从日期和字符串的列组合中删除时间戳?

    python - 类型提示子类返回 self

    python - 引号中的 mypy 显式类型提示仍然给出未定义的错误

    Python类型提示——为dict子类指定key、value类型

    python - Django 测试运行器因 "relation does not exist"错误而失败

    python - 如何展平嵌套的字典并在碰撞时使用内部值

    python - 如何使Keras网络不输出全1

    Tensorflow - 如何批量访问每个示例的损失值?

    tensorflow - 在 Tensorflow Serving 中调试批处理(未观察到效果)