我正在尝试使用 numpy 操作来实现tensorflow的 conv1d,暂时忽略步幅和填充。我以为我在previous question之后就明白了但今天意识到,在处理宽度大于 1 的内核时,我仍然没有得到正确的答案。
所以现在我尝试使用 tflearn 作为模板,因为它为我计算内核形状。现在我明白卷积可以计算为矩阵乘法,我正在尝试相应地使用内核矩阵,但我没有得到与 tflearn 相同的答案。检查源代码是相当不透明的,因为它只是调用tensorflow的专用编译实现。
这是我到目前为止所得到的:
inp = np.arange(10).reshape(1,10,1).astype(np.float32)
filters = 2
width = 3
z = tflearn.conv_1d(inp, filters, width, strides=1, padding='same', bias=False)
s = tf.Session()
s.run(tf.global_variables_initializer())
z1, w = s.run([z, z.W])
print('tflearn output shape', z1.shape)
print('tflearn kernel shape', w.shape)
print('numpy matmul shape', (inp @ w).shape)
这表明 tflearn 内核将宽度作为额外维度插入到开头:
tflearn output shape (1, 10, 2)
tflearn kernel shape (3, 1, 1, 2)
numpy matmul shape (3, 1, 10, 2)
因此,我得到的结果具有额外的 3
维度。好吧,那么我如何正确地减少它以获得与 tensorflow 相同的答案呢?我尝试对这个维度求和,但不正确:
print('tflearn output:')
print(z1)
print('numpy output:')
print(np.sum(inp @ w, axis=0))
给予,
tflearn output:
[[[-0.02252221 0.24712706]
[ 0.49539018 1.0828717 ]
[ 0.0315876 2.0945265 ]
[-0.43221498 3.1061814 ]
[-0.89601755 4.117836 ]
[-1.3598201 5.129491 ]
[-1.823623 6.141146 ]
[-2.2874253 7.152801 ]
[-2.7512276 8.164455 ]
[-2.989808 6.7048397 ]]]
numpy output:
[[[ 0. 0. ]
[-0.46380258 1.0116549 ]
[-0.92760515 2.0233097 ]
[-1.3914077 3.0349646 ]
[-1.8552103 4.0466194 ]
[-2.319013 5.0582743 ]
[-2.7828155 6.069929 ]
[-3.2466178 7.0815845 ]
[-3.7104206 8.093239 ]
[-4.174223 9.104893 ]]]
这显然是不同的。 z.W
当然已初始化为随机值,因此这些数字也是随机的,但我正在寻找使它们等于 z1
的 numpy 计算,因为它们正在执行相同的内核。显然它并不像inp @ w
那么简单。
谢谢。
最佳答案
好吧,抱歉,经过一番思考,我已经回答了我自己的问题...这就是我在上一个问题中试图介绍的滑动窗口操作的来源:
y = (inp @ w)
y[0,:,:-2,:] + y[1,:,1:-1,:] + y[2,:,2:,:]
给予,
array([[[ 0.49539018, 1.0828717 ],
[ 0.0315876 , 2.0945265 ],
[-0.43221498, 3.1061814 ],
[-0.89601755, 4.117836 ],
[-1.3598201 , 5.129491 ],
[-1.823623 , 6.141146 ],
[-2.2874253 , 7.152801 ],
[-2.7512276 , 8.164455 ]]], dtype=float32)
这等于 z1
忽略第一行和最后一行,这正是我对 3 点卷积的期望。
编辑:但如果有人可以提出一种更简洁/有效的方式来表达滑动窗口,我将非常感激。我从我之前的问题中想到,即使是滑动窗口也可以在矩阵乘法中考虑在内,所以不幸的是需要显式地编写索引逻辑。
关于python - 使用 numpy 运算实现 conv1d,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59553815/