python - Keras 中用于 idct 的自定义层

标签 python machine-learning keras computer-vision

我正在尝试在 Keras 中编写一个用于 IDCT(逆离散余弦变换)的自定义层,因为与 DCT 相比,Keras 中没有用于 IDCT 的内置函数。所以当我将图层写为:

model = Sequential()
model.add(Conv2D(512,1,activation='relu', input_shape= (8,8,64) ))
model.add(Lambda( lambda x: get_2d_idct_tensor(x) ) )

我的函数定义为:

def get_2d_idct_tensor(coefficients):
   return fftpack.idct(K.transpose(fftpack.idct(K.transpose(coefficients), norm='ortho')), norm='ortho')

我收到以下错误:

----> 9 model.add(Lambda( lambda x: get_2d_idct_tensor(x) ) )
 10 
 11 #model.add(Lambda(lambda x: K.tf.spectral.dct(K.transpose(K.tf.spectral.dct(K.transpose(x), type=2, norm='ortho')), norm='ortho'),input_shape=(8, 8, 512),output_shape=(8, 8, 1) ))

/usr/local/lib/python3.6/dist-packages/keras/models.py in add(self, layer)
520                           output_shapes=[self.outputs[0]._keras_shape])
521         else:
--> 522             output_tensor = layer(self.outputs[0])
523             if isinstance(output_tensor, list):
524                 raise TypeError('All layers in a Sequential model '

/usr/local/lib/python3.6/dist-packages/keras/engine/topology.py in __call__(self, inputs, **kwargs)
617 
618             # Actually call the layer, collecting output(s), mask(s), and shape(s).
--> 619             output = self.call(inputs, **kwargs)
620             output_mask = self.compute_mask(inputs, previous_mask)
621 

/usr/local/lib/python3.6/dist-packages/keras/layers/core.py in call(self, inputs, mask)
683         if has_arg(self.function, 'mask'):
684             arguments['mask'] = mask
--> 685         return self.function(inputs, **arguments)
686 
687     def compute_mask(self, inputs, mask=None):

<ipython-input-14-dae1f7021aae> in <lambda>(x)
  7 model.add(Conv2D(512,1,activation='relu', input_shape= (8,8,64) ))
  8 
----> 9 model.add(Lambda( lambda x: get_2d_idct_tensor(x) ) )
 10 
 11 #model.add(Lambda(lambda x: K.tf.spectral.dct(K.transpose(K.tf.spectral.dct(K.transpose(x), type=2, norm='ortho')), norm='ortho'),input_shape=(8, 8, 512),output_shape=(8, 8, 1) ))

<ipython-input-7-9ac404754077> in get_2d_idct_tensor(coefficients)
 12     """ Get 2D Inverse Cosine Transform of Image
 13     """
---> 14     return fftpack.idct(K.transpose(fftpack.idct(K.transpose(coefficients), norm='ortho')), norm='ortho')
 15 
 16 def get_reconstructed_image(img):

/usr/local/lib/python3.6/dist-packages/scipy/fftpack/realtransforms.py in idct(x, type, n, axis, norm, overwrite_x)
200     # Inverse/forward type table
201     _TP = {1:1, 2:3, 3:2}
--> 202     return _dct(x, _TP[type], n, axis, normalize=norm, overwrite_x=overwrite_x)
203 
204 

/usr/local/lib/python3.6/dist-packages/scipy/fftpack/realtransforms.py in _dct(x, type, n, axis, overwrite_x, normalize)
279 
280     """
--> 281     x0, n, copy_made = __fix_shape(x, n, axis, 'DCT')
282     if type == 1 and n < 2:
283         raise ValueError("DCT-I is not defined for size < 2")

/usr/local/lib/python3.6/dist-packages/scipy/fftpack/realtransforms.py in __fix_shape(x, n, axis, dct_or_dst)
224 
225 def __fix_shape(x, n, axis, dct_or_dst):
--> 226     tmp = _asfarray(x)
227     copy_made = _datacopied(tmp, x)
228     if n is None:

/usr/local/lib/python3.6/dist-packages/scipy/fftpack/basic.py in _asfarray(x)
125     already an array with a float dtype, and do not cast complex types to
126     real."""
--> 127     if hasattr(x, "dtype") and x.dtype.char in numpy.typecodes["AllFloat"]:
128         # 'dtype' attribute does not ensure that the
129         # object is an ndarray (e.g. Series class

AttributeError: 'DType' object has no attribute 'char'

有人可以解释一下这个错误是什么以及为什么会造成这个错误吗?我对 Keras 还很陌生,希望得到一些帮助来为我指明正确的方向。

预先感谢您的时间和帮助...

最佳答案

您正在运行一个需要张量上的 NumPy ndarray 的操作。不幸的是,这行不通。您需要使用张量运算符重新实现自定义操作。

话虽如此,直接使用 Tensorflow 中的函数也可以,例如从 import tensorflow 并在自定义层中使用这些函数可能会为您提供比单独 Keras 后端更多的函数。

关于python - Keras 中用于 idct 的自定义层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50559152/

相关文章:

tensorflow - 在 Keras 模型中使用 Tensorflow feature_column

python - 预期 conv1d_1_input 具有形状 (15, 512),但得到的数组具有形状 (4, 512)

python - 如何在 selenium 测试中将基本 url 作为参数

python - fatal error : Python. h:没有这样的文件或目录

python - Shutil.rmtree() 引发异常 WindowsError : Access is denied:

python - 如何检测数据框中某些值的条纹?

python - Keras 中的 EarlyStopping 会保存最好的模型吗?

python - 如何获得训练集和验证集的不同指标?

java - 在 Java TensorFlow 1.15 中使用 Python 构建的 TensorFlow 2.1.0 模型 |图表中没有名为 [input] 的操作

python - 如何使用 ML 模型和 FastAPI 处理多个用户的请求?