python - tensorflow 逐元素乘法广播?

标签 python tensorflow

tensorflow 是否提供在最后一个维度上进行元素乘法广播的功能?

以下是我正在尝试执行的操作和不起作用的示例:

import tensorflow as tf
x = tf.constant(5, shape=(1, 200, 175, 6), dtype=tf.float32)
y = tf.constant(1, shape=(1, 200, 175), dtype=tf.float32)
tf.math.multiply(x, y)

本质上,我希望对于沿最后一个维度的每个 x 切片,与 y 进行逐元素矩阵乘法。

我发现这个问题询问类似的操作:Efficient element-wise multiplication of a matrix and a vector in TensorFlow

不幸的是,建议的方法(使用 tf.multiply() )现在不再有效。相应的 tf.math.multiply 也不起作用,因为上面的代码给出了以下错误:

Traceback (most recent call last):
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1864, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 175 and 200 for 'Mul' (op: 'Mul') with input shapes: [1,200,175,6], [1,200,175].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py", line 322, in multiply
    return gen_math_ops.mul(x, y, name)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 6490, in mul
    "Mul", x=x, y=y, name=name)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2027, in __init__
    control_input_ops)
  File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1867, in _create_c_op
    raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 175 and 200 for 'Mul' (op: 'Mul') with input shapes: [1,200,175,6], [1,200,175].

我可以想到一种可行的方法:将 y 重复 6 次,使其与 x 具有完全相同的形状,然后进行逐元素乘法。

但是在 tensorflow 中是否有更快且内存有效的方法来做到这一点?

最佳答案

这应该可以实现你想要的:

x = np.array([[[1,2,3],[4,5,6],[7,8,9],[10,11,12]]])
# [[[ 1  2  3]
#   [ 4  5  6]
#   [ 7  8  9]
#   [10 11 12]]]
y = np.array([[1,2,3,4]])
# [[1 2 3 4]]
y = tf.expand_dims(y, axis=-1)
mul = tf.multiply(x, y)
# [[[ 1  2  3]
#   [ 8 10 12]
#   [21 24 27]
#   [40 44 48]]]

最后,使用您需要的形状:

x = np.random.rand(1, 200, 175, 6)
y = np.random.rand(1, 200, 175)
y = tf.expand_dims(y, axis=-1)
mul = tf.multiply(x, y)
with tf.Session() as sess:
    print(sess.run(mul).shape)
    # (1, 200, 175, 6)
​

关于python - tensorflow 逐元素乘法广播?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57617294/

相关文章:

python - curve_fit 和 scipy.odr 的比较 - 绝对西格玛

python - 如何使用 python 通过 beautifulsoup 检索文本

python - 使用 Python - 从一些 html 中获取表格并显示它?

python - LSTM 神经网络输入/输出维度错误

python - 命令 "python setup.py egg_info"失败,错误代码为 1 in/tmp/pip-build-hn4hol_z/spacy/

python - 使用 TensorRT 部署语义分割网络(U-Net)(不支持上采样)

python - 在以 tensorflow 概率层结尾的网络上使用生成器

python - 'tensorflow_core.estimator' 没有属性 'inputs' ,为什么会发生这种情况?

tensorflow tf.maximum(0, x) 返回错误

python - 如何以 snmp OID 作为键获取 python 字典作为对象