python - 如何在 tensorflow 中将序列映射到序列?

标签 python tensorflow

我有一个 3 维形状矩阵(高度、宽度、4)。事实上,它是一个位图,每个像素都有 RGBA 值。我想将每个 RGBA 集减少为具有两个值的集,例如 [x,y]。

查看 imgur com/Blr2EQC 上的图片

我尝试过使用map_fn

import cv2
import tensorflow as tf

def map_pixel_to_vector(elt):
    b = elt[0] - 127
    g = elt[1] - 127
    r = elt[2] - 127
    a = elt[3] - 127

    dx = (g * 127) + r
    dy = (a * 127) + b
    return [dx,dy]

file = "example.png"
frame = cv2.imread(file, cv2.IMREAD_UNCHANGED
s = tf.shape(frame)

# reshape to list of pixels
elts = tf.reshape(frame, (s[0]*s[1],4))

# cast from uint8 to int32 to support negative output
elts = tf.dtypes.cast(elts, tf.int32)

# map each pixel to output
elts = tf.map_fn(map_pixel_to_vector, elts)

# reshape back to image resolution
elts = tf.reshape(elts, (s[0], s[1], 2)

现在我希望这能起作用,每个 [rgba] 像素都会减少到 [xy] 像素,但我得到的是

ValueError: The two structures don't have the same nested structure.

First structure: type=DType str=<dtype: 'int32'>

Second structure: type=list str=[<tf.Tensor: id=262537, shape=(), dtype=int32, numpy=98>, <tf.Tensor: id=262540, shape=(), dtype=int32, numpy=210>]

More specifically: Substructure "type=list str=[<tf.Tensor: id=262537, shape=(), dtype=int32, numpy=98>, <tf.Tensor: id=262540, shape=(), dtype=int32, numpy=210>]" is a sequence, while substructure "type=DType str=<dtype: 'int32'>" is not

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 97, in <module>
    loss = loss_fn(exc, [outputs[-1]], [inputs[-1]])
  File "main.py", line 36, in loss_fn
    elts = tf.map_fn(reduce_pixel_to_vector, elts)
  File "/usr/lib/python3.7/site-packages/tensorflow_core/python/ops/map_fn.py", line 268, in map_fn
    maximum_iterations=n)
  File "/usr/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2714, in while_loop
    loop_vars = body(*loop_vars)
  File "/usr/lib/python3.7/site-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2705, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "/usr/lib/python3.7/site-packages/tensorflow_core/python/ops/map_fn.py", line 258, in compute
    nest.assert_same_structure(dtype or elems, packed_fn_values)
  File "/usr/lib/python3.7/site-packages/tensorflow_core/python/util/nest.py", line 313, in assert_same_structure
    % (str(e), str1, str2))

任何帮助将不胜感激。

最佳答案

您的函数map_pixel_to_vector返回一个列表,而不是张量。您可以将其变成张量,例如使用 tf.stacktf.convert_to_tensor :

def map_pixel_to_vector(elt):
    b = elt[0] - 127
    g = elt[1] - 127
    r = elt[2] - 127
    a = elt[3] - 127

    dx = (g * 127) + r
    dy = (a * 127) + b
    return tf.stack([dx, dy])

但是,您可以在没有 tf.map_fn 的情况下执行相同的操作更简单有效,如下所示:

import tensorflow as tf
import cv2

file = "example.png"
frame = tf.constant(cv2.imread(file, cv2.IMREAD_UNCHANGED))
elts = tf.dtypes.cast(frame, tf.int32)
r, g, b, a = tf.unstack(elts - 127, num=4, axis=-1)
elts = tf.stack([(g * 127) + r, (a * 127) + b], axis=-1)

关于python - 如何在 tensorflow 中将序列映射到序列?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57972913/

相关文章:

python - Pytest 的 "caplog" fixture 的类型提示是什么?

python - 在 matplotlib 图中更改 X 刻度

python - Peewee 和更新

tensorflow - 在 tensorflow 中检索未命名的变量

python - 为图像数据集打乱补丁的更好方法 - tf.data 输入管道

python - 如何检查 TPU 设备类型是 v2 还是 v3?

python - Python 对 FreeBSD 的支持是否与对 CentOS/Ubuntu/其他 Linux 版本的支持一样好?

Tensorflow,用 tf.train.Saver 保存了什么?

python - 使用 fit_generator 的训练模型不显示 val_loss 和 val_acc 并且在第一个时期中断

python - 分布式 TensorFlow [异步,图间复制] : which are the exactly interaction between workers and servers regarding Variables update