python - 如何将两个形状为 (2,2) 和 (2,) 的 tf.Variable 类型数组相乘?

标签 python python-3.x numpy tensorflow

我正在尝试转换代码

import numpy as np
W_np = np.ones((2,2))
x_np = np.array([1,0])
print(W_np @ x_np)

output: array([1., 1.])

转换为 TensorFlow 等效项

import tensorflow as tf
tf.enable_eager_execution()
W = tf.Variable(tf.ones(shape=(2,2)), name="W", dtype=tf.float32)
x = tf.Variable(np.array([1,0]), name = 'x', dtype=tf.float32)
W @ x

但收到错误消息

InvalidArgumentError: In[1] is not a matrix. Instead it has shape [2] [Op:MatMul] name: matmul/

如何解决?

最佳答案

接下来的一个应该可以工作:

import tensorflow as tf
import numpy as np

tf.enable_eager_execution()
W = tf.Variable(tf.ones(shape=(2,2)), name="W", dtype=tf.float32)
x = tf.Variable(np.array([1,0]), name = 'x', dtype=tf.float32)
res =tf.linalg.matmul(W, tf.expand_dims(x, 0), transpose_b=True)
print(res)
# tf.Tensor([[1.] [1.]], shape=(2, 1), dtype=float32)

# or slightly different way
res =tf.linalg.matmul(tf.expand_dims(x, 0), W)
res = tf.squeeze(res)
print(res)
# tf.Tensor([1. 1.], shape=(2,), dtype=float32)

关于python - 如何将两个形状为 (2,2) 和 (2,) 的 tf.Variable 类型数组相乘?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56196379/

相关文章:

python - Matplotlib.pyplot.hist() 很慢

python - 自动从csv文件中提取数据到特定的矩阵位置

python - 如何从文本文件中去除标点符号

python-3.x - pandas 数据帧中的跳转点 : the moment when the value in a column gets changed

python - 根据最大计数复制 Dataframe 中的数据

python - 提取标签之间的 HTML

macos - 该应用程序无法启动,因为它无法找到或加载Qt平台插件 "cocoa"

python 2D 数组,从 bool 掩码中切片/删除元素,并将结果保留为 2D 数组

python - 基于其他数组Python更新数组

python - 两个字符串之间的顺序交集