python - 在 Pytorch 中强制执行 nn.Parameter(矩阵)参数中的结构

标签 python matrix parameters pytorch

在 PyTorch 库中,可以定义具有某些初始值的神经网络参数 nn.Parameter,例如,

some_param = nn.Parameter(data=torch.rand(4,4))

就我而言,我想对此参数强制执行某些结构。例如,考虑严格的下三角形式(在矩阵参数的情况下),因此 some_param 的形式为:

[ 0   0   0   0 ]
[a21  0   0   0 ]
[a31 a32  0   0 ]
[a41 a42 a43  0 ]

但是,如果我使用初始化参数

some_param = nn.Parameter(data=torch.tril(torch.rand(4,4),-1))

那么对角线及其上方的零在训练期间可以变得非零...我如何确保该参数在训练期间保留其结构?

最佳答案

您可以显式仅存储 6 个参数,然后即时构建完整矩阵:

class StructuredParameter(nn.Module):
  def __init__(self, ...):
    self.explicit_p = nn.Parameter(torch.rand(6,))
    self.tril_ind = torch.tril_indice(4, 4, -1)

  def forward(self, ...):
    # before using some_param - create it
    some_param = torch.zeros(4, 4)
    some_params[self.tril_ind] = self.explicit_p

    # use some_param ...

关于python - 在 Pytorch 中强制执行 nn.Parameter(矩阵)参数中的结构,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71382161/

相关文章:

python - 替换 XML 文件中的 python 字符串

Python cx_Oracle 绑定(bind)变量

python - 在 pygame 中使用player.jump 玩家仅跳跃一次

matlab - 如何使用矩阵作为输入来训练 Matlab 神经网络?

javascript - 使用矩阵围绕中心旋转 Svg 路径

c# - 将参数传递给基类构造函数

C++构造函数默认参数

python - 相当于 'in' 用于比较两个 Numpy 数组

c - 行优先与列优先混淆

java - 如何确保强制参数传递到 Spring MVC Controller 方法中?