python - 在Tensorflow 2.0中的tf.function input_signature中使用字典

标签 python tensorflow2.0

我正在使用Tensorflow 2.0并面临以下情况:

@tf.function
def my_fn(items):
    .... #do stuff
    return

如果items是张量的字典,例如:
item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}

有没有一种使用tf.function的input_signature参数的方法,所以我可以强制tf2避免在item1例如tf.zeros([2,1])时创建多个图形?

最佳答案

输入签名必须是一个列表,但是列表中的元素可以是字典或张量规范列表。在您的情况下,我会尝试:( name属性是可选的)

signature_dict = { "item1": tf.TensorSpec(shape=[2], dtype=tf.int32, name="item1"),
                   "item2": tf.TensorSpec(shape=[], dtype=tf.int32, name="item2") } 
              

# don't forget the brackets around the 'signature_dict'
@tf.function(input_signature = [signature_dict])
def my_fn(items):
    .... # do stuff
    return

# calling the TensorFlow function
my_fun(items)
但是,如果要调用my_fn创建的特定具体函数,则必须解压缩字典。您还必须在name中提供tf.TensorSpec属性。
# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs 
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
                                             
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)
这很烦人,但应该在TensorFlow 2.3中解决。 (请参见TF“具体功能”指南的末尾)

关于python - 在Tensorflow 2.0中的tf.function input_signature中使用字典,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60827999/

相关文章:

python - tensorflow 在损失函数中使用输入

c++ - 在 CPP 中运行 tensorflow 模型

tensorflow - Tensorflow 中 LSTM 层的门权重顺序

python - 将 AST 节点转换为向量/数字

python - webroot 上的 django urlconf 404

python - 带有输入文件的 Matplotlib 中的直方图

python - 检查一首歌是否在pygame中结束

python - 根据另一列向日期添加一天

python - 如何为 QDockWidget 的选项卡和窗口标题设置不同的文本?

tensorflow - 如何使用TF2将加载的h5模型正确保存到pb