python - 如何将张量转换为字符串

标签 python tensorflow tensorflow-datasets

我正在测试 tf.data(),这是现在推荐的批量输入数据的方式, 但是,我正在加载自定义数据集,因此我需要“str”格式的文件名。但是在创建 tf.Dataset.from_tensor_slices 时,它们是 Tensor 对象。

def load_image(file, label):
        nifti = np.asarray(nibabel.load(file).get_fdata()) # <- here is the problem

        xs, ys, zs = np.where(nifti != 0)
        nifti = nifti[min(xs):max(xs) + 1, min(ys):max(ys) + 1, min(zs):max(zs) + 1]
        nifti = nifti[0:100, 0:100, 0:100]
        nifti = np.reshape(nifti, (100, 100, 100, 1))
        nifti = tf.convert_to_tensor(nifti, np.float32)
        return nifti, label


    def load_image_wrapper(file, labels):
       file = tf.py_function(load_image, [file, labels], (tf.string, tf.int32))
       return file


    dataset = tf.data.Dataset.from_tensor_slices((train, labels))
    dataset = dataset.map(load_image_wrapper, num_parallel_calls=6)
    dataset = dataset.batch(6)
    dataset = dataset.prefetch(buffer_size=6)
    iterator = iter(dataset)
    batch_of_images = iterator.get_next()

这里是错误: typeerror expected str bytes 或 os.pathlike object not Tensor

我试过使用“py_function”包装器,但无济于事。 有什么想法吗?

最佳答案

解决了问题,TensorFlow 2.1:

    def load_image(file, label):
    nifti = np.asarray(nibabel.load(file.numpy().decode('utf-8')).get_fdata())

    xs, ys, zs = np.where(nifti != 0)
    nifti = nifti[min(xs):max(xs) + 1, min(ys):max(ys) + 1, min(zs):max(zs) + 1]
    nifti = nifti[0:100, 0:100, 0:100]
    nifti = np.reshape(nifti, (100, 100, 100, 1))
    nifti = tf.convert_to_tensor(nifti, np.float64)
    return nifti, label


def load_image_wrapper(file, labels):
    return tf.py_function(load_image, [file, labels], [tf.float64, tf.float64])


dataset = tf.data.Dataset.from_tensor_slices((train, labels))
dataset = dataset.map(load_image_wrapper, num_parallel_calls=6)
dataset = dataset.batch(2)
dataset = dataset.prefetch(buffer_size=2)
iterator = iter(dataset)
batch_of_images = iterator.get_next()

关于python - 如何将张量转换为字符串,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60496435/

相关文章:

tensorflow - 如何评估 tensorflow 估计器每个时期的测试数据集

python - numpy Loadtxt 函数似乎消耗了太多内存

computer-vision - 如何在 TensorFlow 中为场景标签实现逐像素分类?

Tensorflow/models 使用 COCO 90 类 ID 虽然 COCO 只有 80 个类别

python - 我应该如何使用 Keras(和 TensorFlow)编写负二项分布的损失函数?

python - 为 RGB 图片校正 tf.nn.l2_normalize(x, dim, epsilon=1e-12, name=None) 中的暗淡参数

python - 我在 Tensorflow (python 3.7) 中加载数据集时遇到问题

python - 使用 pyenv 安装 Python 3 后无法导入 tkinter

python - 返回节点邻域中最大的节点

python - 为什么这段 C 代码通过 Python 的 ctypes 运行时的执行速度是直接运行时的一半?