我正在学习 pytorch,要对这里以这种方式创建的数据进行基本的线性回归:
from sklearn.datasets import make_regression
x, y = make_regression(n_samples=100, n_features=1, noise=15, random_state=42)
y = y.reshape(-1, 1)
print(x.shape, y.shape)
plt.scatter(x, y)
我知道使用 tensorflow 这段代码可以解决:model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=1, activation='linear', input_shape=(x.shape[1], )))
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.05), loss='mse')
hist = model.fit(x, y, epochs=15, verbose=0)
但我需要知道 pytorch 等价物会是什么样子,我试图做的是:# Model Class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.linear = nn.Linear(1,1)
def forward(self, x):
x = self.linear(x)
return x
def predict(self, x):
return self.forward(x)
model = Net()
loss_fn = F.mse_loss
opt = torch.optim.SGD(modelo.parameters(), lr=0.05)
# Funcao para treinar
def fit(num_epochs, model, loss_fn, opt, train_dl):
# Repeat for given number of epochs
for epoch in range(num_epochs):
# Train with batches of data
for xb, yb in train_dl:
# 1. Generate predictions
pred = model(xb)
# 2. Calculate Loss
loss = loss_fn(pred, yb)
# 3. Campute gradients
loss.backward()
# 4. Update parameters using gradients
opt.step()
# 5. Reset the gradients to zero
opt.zero_grad()
# Print the progress
if (epoch+1) % 10 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# Training
fit(200, model, loss_fn, opt, data_loader)
但是模型没有学到任何东西,我不知道我还能做什么。输入/输出尺寸为 (1/1)
最佳答案
数据集
首先,您应该定义 torch.utils.data.Dataset
import torch
from sklearn.datasets import make_regression
class RegressionDataset(torch.utils.data.Dataset):
def __init__(self):
data = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=42)
self.x = torch.from_numpy(data[0]).float()
self.y = torch.from_numpy(data[1]).float()
def __len__(self):
return len(self.x)
def __getitem__(self, index):
return self.x[index], self.y[index]
它转换 numpy
数据到 PyTorch 的 tensor
内__init__
并将数据转换为 float
( numpy
默认为 double
而 PyTorch 的默认值为 float
以使用更少的内存)。除此之外,它只会返回
tuple
特征和各自的回归目标。合身
差不多了,但是您必须使模型的输出变平(如下所述)。
torch.nn.Linear
将返回形状张量 (batch, 1)
而你的目标是形状 (batch,)
. flatten()
将删除不必要的 1
尺寸。# 2. Calculate Loss
loss = criterion(pred.flatten(), yb)
模型这就是你真正需要的:
model = torch.nn.Linear(1, 1)
任何层都可以直接调用,不需要forward
和简单模型的继承。打电话
剩下的几乎没问题,你只需要创建
torch.utils.data.DataLoader
并传递我们数据集的实例。什么 DataLoader
是不是有问题__getitem__
的 dataset
多次创建指定大小的批次(还有一些其他有趣的事情,但这就是想法):dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
fit(5000, model, criterion, optimizer, dataloader)
还要注意我用过 torch.nn.MSELoss()
,当我们传递对象时,在这种情况下它看起来比函数更好。全码
为了使它更容易:
import torch
from sklearn.datasets import make_regression
class RegressionDataset(torch.utils.data.Dataset):
def __init__(self):
data = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=42)
self.x = torch.from_numpy(data[0]).float()
self.y = torch.from_numpy(data[1]).float()
def __len__(self):
return len(self.x)
def __getitem__(self, index):
return self.x[index], self.y[index]
# Funcao para treinar
def fit(num_epochs, model, criterion, optimizer, train_dl):
# Repeat for given number of epochs
for epoch in range(num_epochs):
# Train with batches of data
for xb, yb in train_dl:
# 1. Generate predictions
pred = model(xb)
# 2. Calculate Loss
loss = criterion(pred.flatten(), yb)
# 3. Compute gradients
loss.backward()
# 4. Update parameters using gradients
optimizer.step()
# 5. Reset the gradients to zero
optimizer.zero_grad()
# Print the progress
if (epoch + 1) % 10 == 0:
print(
"Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, loss.item())
)
dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
fit(5000, model, criterion, optimizer, dataloader)
你应该四处走走 0.053
损失左右,因人而异 noise
或其他更难/更容易回归任务的参数。
关于python - tensorflow 线性回归的pytorch等价物是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63830441/