python - TensorFlow 中相当于 PyTorch 中的 expand() 的函数是什么?

标签 python tensorflow pytorch

假设我有一个 2 x 3 矩阵,我想创建一个 6 x 2 x 3 矩阵,其中第一维中的每个元素都是原始 2 x 3 矩阵。

在 PyTorch 中,我可以这样做:

import torch
from torch.autograd import Variable
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
x = Variable(torch.from_numpy(x))

# y is the desired result
y = x.unsqueeze(0).expand(6, 2, 3)

在 TensorFlow 中执行此操作的等效方法是什么?我知道 unsqueeze() 等同于 tf.expand_dims() 但我不知道 TensorFlow 有任何等同于 expand() 的东西。我正在考虑在 1 x 2 x 3 张量列表上使用 tf.concat,但我不确定这是否是最好的方法。

最佳答案

pytorch expand 的等效函数是 tensorflow tf.broadcast_to

文档:https://www.tensorflow.org/api_docs/python/tf/broadcast_to

关于python - TensorFlow 中相当于 PyTorch 中的 expand() 的函数是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48226221/

相关文章:

python - 如何使用 sql alchemy 进行联接查询?

python - 有没有把分类和回归合二为一的机器学习算法

python - PycURL 通过特定接口(interface)发送 DNS 流量

tensorflow - 图像分类迁移学习需要负例吗?

tensorflow - 为什么卷积神经网络的张量维度是给定的? - TensorFlow

python - pytorch 中 expand 的 numpy 等价物是什么?

python - Pandas 数据帧行到列表字典,使用每行的第一个值作为键

python - 剪切渐变时出错

keras - 如何加载 .bin 模型文件?

pytorch - (Pytorch)为什么 conv2d 结果都不同。它们的数据类型都是整数,没有 float