python - tensorflow 中的线性模型

标签 python tensorflow linear-algebra

我试图在 Tensorflow 中生成一个简单的线性模型。这是代码...

N        = 400
features = 100
nSteps   = 1000

data = (np.random.randn(N, features), np.random.randint(0, 2, N))

W = tf.placeholder(tf.float32, shape=(features,1), name='W')
b = tf.placeholder(tf.float32, shape=(features,1), name='b')
d = tf.constant(data[0], dtype=tf.float32)

result = tf.add( tf.matmul(d, W), b)

事实证明,b的尺寸可能存在一些问题,但由于某种原因,据我所知,它们都没有问题......

不知道为什么会抛出错误。有人可以帮忙吗?

注意:

result = tf.matmul(d, W)

这没关系。

我检查了结果的形状,与b的形状相同。不太确定可能是什么问题。

最佳答案

在线性模型(即输出层中的一个单元)中,b 应该是一个标量。

从数学上讲,对于单个观察,您有:结果 = WX + b,其中维度 W [1 x 特征],X [功能×1]。那么,WX 是标量。因此 b 应该是一个标量。

因此,您应该将 b 更改为以下内容,以获得正确的线性模型并计算尺寸:

b = tf.placeholder(tf.float32, shape=(1,1), name='b')

关于python - tensorflow 中的线性模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46417690/

相关文章:

python - TensorFlow 如何将构建与执行分开以实现数据并行

algorithm - 我如何衡量某些词的趋势,比如 Twitter?

python - 使用 NumPy 计算特征向量的说明

python - 迭代多个 python 列表上的元素

python - 排除 Pyinstaller 中的文件

java - Jython sys 模块中缺少函数

mysql - 在Tensorflow中读取Mysql数据库

python - 从 Perl 调用 Python 模块

c++ - tensorflow 教程中有关量化的错误

matlab - 在有限域域中查找 A.x = b 中 x 的 "all solutions"