python - 具有线性回归模型的 fillna 从数据框 pandas 中的两列构建

标签 python pandas

我有这样的数据框:

    sepal length    sepal width   petal length  petal width     target
0      4.9              3.5            1.4          0.2         setosa
1      4.9              3.0            1.4          0.2         setosa
2      4.7              3.2            1.3          0.2         setosa
3      4.6              3.1            1.5          0.2         setosa
4      5.0              3.6            1.4          NaN         setosa
      ...   

我使用花瓣宽度和花瓣长度创建了 LinearRegression() 模型。现在我想使用我训练过的 linear_regression 模型来填充 NaN 值。

这是我尝试过的方法,它确实有效,但我很想知道是否有更有效的方法。

def fillna_linear_reg(length, width):
    if pd.isna(length):
        pred_length = lin_reg.predict([[width]]) 
        return pred_length[0][0]
    else:
        return length

iris_df["petal length (cm)"] = iris_df.apply(lambda x: fillna_linear_reg(x["petal length (cm)"], x["petal width (cm)"]), axis=1)

提前致谢!

最佳答案

是的,有一种更有效的方法。您可以使用预测并一次分配所有缺失值。尽可能避免使用 df.apply。它会降低性能,尤其是当与其他可矢量化函数一起使用时,例如 predict(或者甚至已经矢量化)(我假设是)sklearn 模型的方法。

def fillna_linear_reg(lin_reg, length, width):
    nan_mask = length.isna()
    pred_length = lin_reg.predict(width.loc[nan_mask])
    length.loc[nan_mask] = pred_length

fillna_linear_reg(
    lin_reg, iris_df.loc[:, "petal length (cm)"], iris_df.loc[:, "petal width (cm)"]
)

根据您用于训练的机器学习模块,您可能需要将 x 数据作为二维数组传递给 predict 方法,然后压缩回一维数组。如果是这样,您可以将包含预测的行替换为:

pred_length = np.squeeze(lin_reg.predict(np.atleast_2d(width.loc[nan_mask])))

如果您添加明确的形状信息,这当然可以简化。

关于python - 具有线性回归模型的 fillna 从数据框 pandas 中的两列构建,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65538171/

相关文章:

python - 所以我试图回答这个 : Given a list of ints, 如果数组在某处包含 3 旁边的 3 则返回 True

python - 扭曲的端口转发代理将数据发送回客户端

python - 计算一个日期在另一个日期的 x 个月内出现的次数

Python pandas 使用 read_hdf 和 HDFStore.select 从 HDF5 文件读取特定值

python - 将新列数组添加到 Pandas 数据框中

python - 将 python 数据框中的特定列转储为行

python - 在分析 Cython 代码时,什么是 `stringsource` ?

python - 在 Python 中打印字符串的一部分,同时跳过其他部分

python - 如何使用脚本重新创建 IPython 的 '--pylab' 选项的效果?

python - 使用多索引标准化 pandas DataFrame