python - 根据新数据点更新预先训练的深度学习模型

标签 python deep-learning classification conv-neural-network pytorch

考虑 ImageNet 上的图像分类示例,如何使用新数据点更新预训练模型。 我已经加载了预训练的模型。我有一个新数据点,与之前训练模型的原始数据的分布完全不同。因此,我想借助新数据点来更新/微调模型。该如何去做呢?有人可以帮我做吗?我使用pytorch 0.4.0进行实现,在GPU Tesla K40C上运行。

最佳答案

如果您不想更改分类器的输出(即类的数量),那么您可以简单地使用新的示例图像继续训练模型,假设它们被 reshape 为与预训练模型相同的形状接受。

另一方面,如果您想更改预训练模型中的类数,则可以用新层替换最后一个全连接层,并在新样本上仅训练该特定层。以下是此案例的示例代码,来自 PyTorch's autograd mechanics notes :

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

关于python - 根据新数据点更新预先训练的深度学习模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53624766/

相关文章:

python - 致命标志解析错误 :Unknown command line flag 'logtostderr'

python - 随时间传播卫星目录的有效方法

python - 我无法使用 python-apt 获取 'required_changes'

python - 使用 Sequential API 从 Keras 自动编码器中提取编码/解码模型

python - 如何使用 numpy 比较多维数组的对角相对元素

python - PyTorch:state_dict 和 parameters() 有什么区别?

machine-learning - pytorch SGD 的默认批量大小是多少?

python - 使用 train_test_split 拆分数据时与之后加载 csv 文件时的精度不同

python - 使用 scikit-learn 进行多标签文本分类,使用哪些分类器?

python - 不同阈值的特异性(与 sklearn.metrics. precision_recall_curve 相同)