python-3.x - 如何在 tensorflow 2 数据集中使用带有元组的 map ?

标签 python-3.x tensorflow-datasets tensorflow2.0

尝试将元组映射到 tf 2 数据集中的元组(请参见下面的代码)。我的输出(请参见下文)显示 map 函数仅被调用一次。我似乎无法获取元组。

我如何从输入参数 a 中获取“a”、“b”、“c”:

tuple Tensor("args_0:0", shape=(3,), dtype=string)
type <class 'tensorflow.python.framework.ops.Tensor'>

编辑:似乎使用 Dataset.from_tensor_slices 一次生成所有数据。这解释了为什么 map 只被调用一次。所以我可能需要以其他方式制作数据集。

from __future__ import absolute_import, division, print_function, unicode_literals
from timeit import default_timer as timer
print('import tensorflow')
start = timer()
import tensorflow as tf
end = timer()
print('Elapsed time: ' + str(end - start),"for",tf.__version__)
import numpy as np
def map1(tuple):
    print("<<<")
    print("tuple",tuple)
    print("type",type(tuple))
    print("shape",tuple.shape)
    print("tuple 0",tuple[0])
    print("type 0",type(tuple[0]))
    print("shape 0",tuple.shape[0])
    # how do i get "a","b","c" from the input parameter?
    print(">>>")
    return ("1","2","3")
l=[]
l.append(("a","b","c"))
l.append(("d","e","f"))
print(l)
ds=tf.data.Dataset.from_tensor_slices(l)
print("ds",ds)
print("start mapping")
result = ds.map(map1)
print("end mapping")


$ py mapds.py
import tensorflow
Elapsed time: 12.002168990751619 for 2.0.0
[('a', 'b', 'c'), ('d', 'e', 'f')]
ds <TensorSliceDataset shapes: (3,), types: tf.string>
start mapping
<<<
tuple Tensor("args_0:0", shape=(3,), dtype=string)
type <class 'tensorflow.python.framework.ops.Tensor'>
shape (3,)
tuple 0 Tensor("strided_slice:0", shape=(), dtype=string)
type 0 <class 'tensorflow.python.framework.ops.Tensor'>
shape 0 3
>>>
end mapping

最佳答案

map 函数 (map1) 返回的一个或多个值决定返回数据集中每个元素的结构。 [Ref]
在您的情况下,result 是一个 tf 数据集,您的编码没有任何问题。

要检查每个元组是否正确映射,您可以像下面这样遍历数据集的每个样本:
[更新代码]

    def map1(tuple):
        print(tuple[0].numpy().decode("utf-8")) # Print first element of tuple
        return ("1","2","3")
    l=[]
    l.append(("a","b","c"))
    l.append(("d","e","f"))
    ds=tf.data.Dataset.from_tensor_slices(l)
    ds = ds.map(lambda tpl: tf.py_function(map1, [tpl], [tf.string, tf.string, tf.string]))

    for sample in ds:
        print(str(sample[0].numpy().decode()), sample[1].numpy().decode(), sample[2].numpy().decode())
Output:
a
1 2 3
d
1 2 3

希望对您有所帮助。

关于python-3.x - 如何在 tensorflow 2 数据集中使用带有元组的 map ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58969880/

相关文章:

python - TF2.0 Data API 从每个类标签中获取 n_i 个样本

c - 通过 C API 访问 tensorflow 2.0 SavedModel 的输入和输出张量

python - tensorflow 在损失函数中使用输入

python - Pyqt QTablewidget 自动换行

python - 如何使用 gpython 脚本从 sample.tar.gz 文件中提取选定的文件

tensorflow - 通过 feature_columns 使用 Dataset API 将自由文本特征引入 Tensorflow Canned Estimators

tensorflow2.0 - tensorflow 2.x : Convert byte string to int in map

python - Tensorflow 2 + Keras 的知识蒸馏损失

python-3.x - 如何通过 SOAP API 锁定 virtualbox 以获取屏幕截图

python - 如何解析 python 中的 requests.session cookie 属性