python - 如何在 Tensorflow 中复制 PyTorch 的 nn.functional.unfold 函数?

标签 python tensorflow pytorch

我想用tensorflow重写pytorch的torch.nn.functional.unfold功能:

#input x:[16, 1, 50, 36]
x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)
#output x:[16, 180, 16]
我尝试使用函数tf.extract_image_patches() :x = tf.extract_image_patches(x,ksizes=[1, 1,5, 98],strides=[1, 1, 3, 1], rates=[1, 1, 1, 1],padding='VALID')输入 x.shape : [16,1,64,98]我得到输出 x.shape : [16,1,20,490]然后我 reshape X[16,490,20] ,这是我所期望的。
但是当我提供数据时出现错误:
UnimplementedError (see above for traceback): Only support ksizes across space.
[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]
如何使用 tensorflow 重写 pytorch torch.nn.functional.unfold更改X的功能?

最佳答案

x = tf.reshape(x, [16, 50, 36, 1])
x = tf.extract_image_patches(x, ksizes=[1, 4, 98, 1], strides=[1, 4, 1, 1], rates=[1, 1, 1, 1], padding='VALID')

关于python - 如何在 Tensorflow 中复制 PyTorch 的 nn.functional.unfold 函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64523441/

相关文章:

python - 当使用 .clamp 而不是 torch.relu 时,Pytorch Autograd 会给出不同的渐变

python - 为什么我的 pytorch NN 返回一个 nan 的张量?

python - Pandas 滚动窗口返回一个数组

python - uwsgi 在作为 systemd 服务运行时损坏管道

python - find_elements_by_xpath 不包含 3 次特定字符串

python - cx_Freeze 无法包含 Cython .pyx 模块

python - @nd time 给出错误,在 TENSORFLOW 执行期间

python - 如何使用或安装pycocoevalcap?

python - 在 tensorflow 中重用关闭的 session

python - 如何解决与输入大小 (torch.Size([1])) 不同的 UserWarning : Using a target size (torch. Size([]))?