python - 具有对角权重矩阵的自定义层

标签 python tensorflow keras keras-layer

我想实现一个具有稀疏输入层的分类器。我的数据大约有 60 个维度,我想检查特征重要性。为此,我希望第一层有一个对角权重矩阵(我想对其应用 L1 内核正则化器),所有非对角线都应该是不可训练的零。因此,每个输入 channel 都是一对一的连接,密集层将混合输入变量。我检查了Specify connections in NN (in keras)Custom connections between layers Keras 。后一个我无法使用,因为 Lambda 层不引入可训练的权重。

但是这样的事情不会影响实际的权重矩阵:

class MyLayer(Layer):
def __init__(self, output_dim,connection, **kwargs):
    self.output_dim = output_dim
    self.connection=connection
    super(MyLayer, self).__init__(**kwargs)

def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                                  shape=(input_shape[1], self.output_dim),
                                  initializer='uniform',
                                  trainable=True)
    self.kernel=tf.linalg.tensor_diag_part(self.kernel)
    self.kernel=tf.linalg.tensor_diag(self.kernel)
    super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

def call(self, x):
    return K.dot(x, self.kernel)

def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

当我训练模型并打印权重时,我没有得到第一层的对角矩阵。

我做错了什么?

最佳答案

不太确定你到底想做什么,因为对我来说,对角线是方阵的东西,这意味着你的层输入和输出维度应该保持不变。

不管怎样,我们先来谈谈方阵的情况。我认为有两种方法可以实现对角线全零值的权重矩阵。

方法一:仅在概念上遵循方阵思想,并使用可训练的权重向量实现该层,如下所示。

# instead of writing y = K.dot(x,W), 
# where W is the weight NxN matrix with zero values of the diagonal.
# write y = x * w, where w is the weight vector 1xN

方法 2:使用默认的 Dense 层,但使用您自己的 constraint .

# all you need to create a mask matrix M, which is a NxN identity matrix
# and you can write a contraint like below
class DiagonalWeight(Constraint):
    """Constrains the weights to be diagonal.
    """
    def __call__(self, w):
        N = K.int_shape(w)[-1]
        m = K.eye(N)
        w *= m
        return w

当然,您应该使用Dense( ..., kernel_constraint=DiagonalWeight())

关于python - 具有对角权重矩阵的自定义层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53744518/

相关文章:

python - Keras Conv2D : filters vs kernel_size

python - 关于keras.utils.Sequence的澄清

python - Google对象检测API——使用faster_rcnn_resnet101_coco模型进行训练

python - 对来自不同列表的相同索引的元素求和

Tensorflow:批量大小 > 1 时无法过度拟合训练数据

python - Keras 功能 API 多输入模型

python - NLTK 正则表达式分词器在正则表达式中不能很好地处理小数点

python - 查找特定数字的完整行

tensorflow - 我如何使用自己的图像在 tensorFlow 中训练我的 CNN 神经网络

android - 无法测试和部署用于推理的 deeplabv3-mobilenetv2 tensorflow-lite 分割模型