python - PyTorch 二元分类不学习

标签 python pytorch

我声明我是 PyTorch 的新手。我编写了这个简单的二元分类程序。我还创建了包含两列随机值的 CSV,其中“ok”列仅当其他两个值包含在我同时决定的两个值之间时,其值为 1。示例:

diam_int,diam_est,ok
37.782,125.507,0
41.278,115.15,1
42.248,115.489,1
29.582,113.141,0
37.428,107.247,0
32.947,123.233,0
37.146,121.537,0
38.537,110.032,0
26.553,113.752,0
27.369,121.144,0
41.632,108.178,0
27.655,111.279,0
29.779,109.268,0
43.695,115.649,1
44.587,116.126,0

在我看来,一切都做得正确,损失实际上降低了(在许多时期后它会略有回升,但我不认为这是一个问题),但是当我尝试在训练后用样本批处理测试我的网络时对于训练集,我得到的预测始终低于 0.5(因此估计输出始终为 0),且趋势完全随机。

with torch.no_grad():
        pred = net(trainSet[10])
        trueVal = ySet[10]
        for i in range(len(trueVal)):
            print(trueVal[i], pred[i])

这是我的网络类(class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    
    def __init__(self) :
        super().__init__()
        self.fc1 = nn.Linear(2, 32)
        self.fc2 = nn.Linear(32, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return torch.sigmoid(x)

这是我的主要类(class):

import torch
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd

from net import Net 

df = pd.read_csv("test.csv")
y = torch.Tensor(df["ok"])
ySet = torch.split(y, 32)
df.drop(["ok"], axis=1, inplace=True)
data = F.normalize(torch.Tensor(df.values), dim=1)
trainSet = torch.split(data, 32)

net = Net()
optimizer = optim.Adam(net.parameters(), lr=0.001)
lossFunction = torch.nn.BCELoss()
EPOCHS = 300

for epoch in range(EPOCHS):
    for i, X in enumerate(trainSet):
        optimizer.zero_grad()
        output = net(X)
        target = ySet[i].reshape(-1, 1)
        loss = lossFunction(output, target)
        loss.backward()
        optimizer.step()

    if epoch % 20 == 0:
        print(loss)

我做错了什么?预先感谢您的帮助

最佳答案

您的模型不适合。将纪元数增加到(例如)3000 可以使模型根据您展示的示例进行完美预测。

然而,经过这么多轮之后,模型可能过度拟合。一个好的做法是使用验证数据(将生成的数据分为训练集和验证集),并检查每个时期的验证损失。当验证损失开始增加时,您开始过度拟合并停止训练。

关于python - PyTorch 二元分类不学习,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71192650/

相关文章:

python - 在 Excel 中发布保存数据

python - 任务之间的 Airflow 延迟

pytorch - 如何将 LIME 与 PyTorch 集成?

machine-learning - 如何在 Pytorch 中使用 torch.nn.Sequential 实现我自己的 ResNet?

python - PyTorch 不能 pickle lambda

python - 针对特定情况实现 SmoothL1Loss

python - Firebase 使用 float 作为键

python - Pandas:散点图,其点的大小由一列的唯一值与另一列的相应值的大小决定

python - reshape Pytorch 张量

python - Django 1.8 : Throwing ImportError: No module named 'MySQLdb'