python - 使用 argmax 在 tensorflow 中切片张量

标签 python tensorflow

我想在tensorflow中做一个动态损失函数。我想计算信号 FFT 的能量,更具体地说,只计算最主要峰值周围大小为 3 的窗口。我无法在 TF 中实现,因为它会抛出很多错误,例如 StrideInvalidArgumentError(回溯见上文):Expected begin, end, and strides to be 1D equal size tensors,但取而代之的是形状 [1,64]、[1,64] 和 [1]。

我的代码是这样的:

self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:,1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)

因为我正在计算 64 位的批处理和数据维度也是 64 位,所以 self.signal 的形状是 (64,64)。我只想计算 FFT 的交流分量。由于信号是实值的,因此只有一半的频谱可以完成这项工作。因此,self.spec_mag 的形状是 (64,32)

此 fft 中的最大值位于 self.argm,其形状为 (64,1)

现在我想通过以下方式计算最大峰值附近 3 个元素的能量:self.spec_mag[self.argm-1:self.argm+2]

但是,当我运行代码并尝试获取 self.frac 的值时,我遇到了多个错误。

最佳答案

访问 argm 时似乎缺少索引。这里是1的固定版本,64位版本。

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:,1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

这里是扩展版本(batch, m, n)

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

你可能想要修复函数名称,因为我在较新版本的 tensorflow 上编辑了这段代码。

关于python - 使用 argmax 在 tensorflow 中切片张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48680607/

相关文章:

TensorFlow:训练 BLSTM 时 CTC 损失没有减少

python - RNN : Get prediction from a text input after the model is trained

python - 基于多种条件的新列

python - 数组中每个值相对于其他值的最小距离

python - 如何在 PyCharm 中进行远程调试

python - tensorflow 2.0 中是否有 cudnnLSTM 或 cudNNGRU 替代方案

python - 类型错误 : '<' not supported between instances of 'torch.device' and 'int'

python - 在 python-riak 客户端 l 中将 erlang 模块和 erlang 函数发送到 mapreduce 面的确切方法是什么

python - 神经网络 - 输入标准化

tensorflow - 批量训练使用更新总和?或平均更新?