python - 如何在 pytorch 中改变神经网络的权重

标签 python deep-learning pytorch genetic-algorithm

我正在使用 pytorch 尝试遗传算法,我正在寻找一种更有效的方法来改变网络的权重(对它们进行小的修改)

现在我有一个次优的解决方案,我循环遍历参数并应用随机修改。

child_agent = network()
for param in child_agent.parameters():
        if len(param.shape) == 4:  # weights of Conv2D
            for i0 in range(param.shape[0]):
                for i1 in range(param.shape[1]):
                    for i2 in range(param.shape[2]):
                        for i3 in range(param.shape[3]):
                            param[i0][i1][i2][i3] += mutation_power * np.random.randn()
        elif len(param.shape) == 2:  # weights of linear layer
            for i0 in range(param.shape[0]):
                for i1 in range(param.shape[1]):
                    param[i0][i1] += mutation_power * np.random.randn()
        elif len(param.shape) == 1:  # biases of linear layer or conv layer
            for i0 in range(param.shape[0]):
                param[i0] += mutation_power * np.random.randn()

此解决方案与我的架构绑定(bind),如果我决定添加更多层,则需要重新编码。有什么方法可以更有效、更清洁地做到这一点吗?最好无论我的网络架构如何,它都能正常工作。

谢谢

最佳答案

pytorchnumpy 是面向 tensor 的,例如您可以对多维数组类对象中包含的多个项目进行操作。

您可以将整个代码更改为这一行:

import torch

child_agent = network()
for param in child_agent.parameters():
    param.data += mutation.power * torch.randn_like(param)

randn_like ( docs here ) 创建与 param 形状相同的随机正态张量。

此外,如果此参数需要 grad(它可能需要),您应该修改它的 data 字段。

MCVE :

import torch

mutation_power = 0.4

child_agent = torch.nn.Sequential(
    torch.nn.Conv2d(1, 3, 3, padding=1), torch.nn.Linear(10, 20)
)

for param in child_agent.parameters():
    param.data += mutation_power * torch.randn_like(param)

关于python - 如何在 pytorch 中改变神经网络的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63951120/

相关文章:

machine-learning - 在 TensorFlow 中实现多层 RNN 最有效的方法是什么?

machine-learning - keras中train_on_batch()有什么用?

machine-learning - `Torch`(即 `nn.SpatialConvolution` )中的卷积层和 `Pytorch`(即 `torch.nn.Conv2d` )中的卷积层有什么不同

pytorch - 使用 2d 张量索引 3d 张量

tensorflow - 哪个Loss function & Metrics更适合多标签分类?二元或分类交叉熵,为什么?

python - 标记化和编码数据集使用过多 RAM

python - python中元组的点积

python - 如何将更长的列表附加到数据框

Python 效率/优化项目 Euler #5 示例

python - 为什么可变实体不能是字典键?