python - 用于 2D ALE 图的 Python PyALE 函数的值错误

标签 python random-forest iml pyale

我正在使用 Python 的 PyALE 函数创建累积局部效应图。我正在使用 RandomForestRegression 函数来构建模型。

我可以创建一维 ALE 图。但是,当我尝试使用相同的模型和训练数据创建 2D ALE 图时,出现值错误。

这是我的代码。

ale(training_data, model=model1, feature=["feature1", "feature2"])

我可以使用以下代码绘制特征 1 和特征 2 的一维 ALE 图。

ale(training_data, model=model1, feature=["feature1"], feature_type="continuous")

ale(training_data, model=model1, feature=["feature2"], feature_type="continuous")

数据框中的任何列都没有缺失值或无限值。

我在使用 2D ALE 绘图命令时遇到以下错误。

ValueError:输入包含 NaN、无穷大或对于 dtype('float32') 来说太大的值。

这是函数 https://pypi.org/project/PyALE/#description 的链接

我不确定为什么会收到此错误。我希望能得到一些帮助。

最佳答案

这个issue已在 PyALE 包的 v1.1.2 版本中解决。对于那些使用早期版本的用户,github 中的问题线程中提到的解决方法是重置提供给函数 ale 的数据集的索引。为了完整起见,这里有一个重现错误和解决方法的代码:

from PyALE import ale
import pandas as pd
import matplotlib.pyplot as plt
import random
from sklearn.ensemble import RandomForestRegressor

# get the raw diamond data (from R's ggplot2)
dat_diamonds = pd.read_csv(
    "https://raw.githubusercontent.com/tidyverse/ggplot2/master/data-raw/diamonds.csv"
)
X = dat_diamonds.loc[:, ~dat_diamonds.columns.str.contains("price")].copy()
y = dat_diamonds.loc[:, "price"].copy()

features = ["carat","depth", "table", "x", "y", "z"]

# fit the model
model = RandomForestRegressor(random_state=1345)
model.fit(X[features], y)

# sample the data
random.seed(1234)
indices = random.sample(range(X.shape[0]), 10000)
sampleData = X.loc[indices, :]

# get the effects.....
# This throws the error
ale_eff = ale(X=sampleData[features], model=model, feature=["z", "table"], grid_size=100)

# This will work, just reset the index with drop=True
ale_eff = ale(X=sampleData[features].reset_index(drop=True), model=model, feature=["z", "table"], grid_size=100)

关于python - 用于 2D ALE 图的 Python PyALE 函数的值错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67417939/

相关文章:

python - 为什么 S3(与 boto 和 django-storages 一起使用)甚至为公共(public)文件提供签名 url?

python - TensorFlow 随机森林与深度学习

android-studio - Android Studio 中的 .iml 文件是什么?

android-studio - Android Studio 中的 .iml 文件是什么?

python - Matplotlib 动画 - 如何将它们导出为在演示文稿中使用的格式?

python - 使用四个 CPU 运行一个 python 脚本

apache-spark - 如何获取 Spark MLlib 随机森林中每个树节点的记录数/类分布?

r - 为什么 R 中的 h2o.randomForest 比 randomForest 包做出更好的预测

Python:获取 PyObject 的字符串表示?