How to extract the values behind a SHAP summary plot so the data can be viewed in a DataFrame?
Here is an MWE:
from sklearn.datasets import make_classification
from shap import Explainer, summary_plot
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
# Generate noisy data
X, y = make_classification(n_samples=1000,
                           n_features=50,
                           n_informative=9,
                           n_redundant=0,
                           n_repeated=0,
                           n_classes=10,
                           n_clusters_per_class=1,
                           class_sep=9,
                           flip_y=0.2,
                           random_state=17)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = RandomForestClassifier()
model.fit(X_train, y_train)
explainer = Explainer(model)
sv = explainer.shap_values(X_test)
# Bar summary plot of the SHAP values computed on X_test
summary_plot(sv, X_test, plot_type="bar")
I tried:
np.abs(shap_values.values).mean(axis=0)
but the result has shape (50, 10). How do I get a single aggregated value per feature and then sort the features by importance?
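The (50, 10) result is already most of the way there. After averaging |SHAP| over the samples, one column per class remains, and summing across those columns gives one score per feature. Below is a minimal sketch that uses a random array as a stand-in for `shap_values.values`. It assumes the layout samples × features × classes, which is consistent with averaging over axis 0 producing (50, 10). The data is illustrative only, not real SHAP output.

```python
import numpy as np

# Stand-in for shap_values.values: 250 samples x 50 features x 10 classes (random, illustration only)
rng = np.random.default_rng(0)
values = rng.normal(size=(250, 50, 10))

# Mean |SHAP| over the samples -> one column per class, shape (50, 10)
per_class = np.abs(values).mean(axis=0)

# Sum over the classes -> one importance score per feature, shape (50,)
importance = per_class.sum(axis=1)

# Feature indices sorted from most to least important
order = np.argsort(importance)[::-1]
print(order[:10])
print(importance[order[:10]])
```

Summing, rather than averaging, across classes does not change the ranking. It does match the total length of each stacked bar in a multi-class bar plot.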
Best answer
You have basically already done it:
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from shap import Explainer, summary_plot

# Generate noisy data
X, y = make_classification(
    n_samples=1000,
    n_features=50,
    n_informative=9,
    n_redundant=0,
    n_repeated=0,
    n_classes=10,
    n_clusters_per_class=1,
    class_sep=9,
    flip_y=0.2,
    random_state=17,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = RandomForestClassifier()
model.fit(X_train, y_train)
explainer = Explainer(model)
sv = explainer.shap_values(X_test)
# The features passed to summary_plot should match the rows that were explained (X_test)
summary_plot(sv, X_test, plot_type="bar")
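Note that features 3, 29, 34, etc. are at the top. The exact ranking can vary between runs, because the RandomForestClassifier is created without a random_state.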
If you check the shape:
np.abs(sv).shape
(10, 250, 50)
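you'll see 10 classes, each holding SHAP values for 250 data points and 50 features. This is the list-of-arrays output of older shap releases. Newer releases (0.45 onward, as far as I know) instead return a single array of shape (250, 50, 10), i.e. samples × features × classes. That layout produces the (50, 10) result in the question.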
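The layout differs between shap versions: older releases return a list of per-class arrays, while newer ones return one array with the classes on the last axis. A small normalizing step lets the rest of the code stay the same. This is a sketch under that assumption. `to_class_sample_feature` is a hypothetical helper, not part of shap, and the inputs below are random stand-ins.

```python
import numpy as np

def to_class_sample_feature(sv):
    """Return SHAP values as an array of shape (n_classes, n_samples, n_features).

    Hypothetical helper. Assumes `sv` is either a list with one
    (n_samples, n_features) array per class (older shap output) or a single
    array shaped (n_samples, n_features, n_classes) (newer shap output).
    """
    if isinstance(sv, list):
        return np.stack(sv, axis=0)
    arr = np.asarray(sv)
    if arr.ndim != 3:
        raise ValueError(f"expected 3 dimensions, got shape {arr.shape}")
    # Move the trailing class axis to the front
    return np.moveaxis(arr, -1, 0)

# Both layouts end up identical
rng = np.random.default_rng(0)
old_style = [rng.normal(size=(250, 50)) for _ in range(10)]
new_style = np.stack(old_style, axis=-1)  # shape (250, 50, 10)
print(to_class_sample_feature(old_style).shape)  # (10, 250, 50)
print(np.array_equal(to_class_sample_feature(old_style), to_class_sample_feature(new_style)))  # True
```

After this step, `np.abs(...).mean(1)` yields the (10, 50) class-by-feature table regardless of which shap version produced the values.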
Averaging the absolute values over the data points gives you everything you need:
aggs = np.abs(sv).mean(1)
aggs.shape
(10, 50)
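This `aggs` is the same table the question already computed, just transposed. Averaging over the sample axis gives (classes, features) in the class-major layout and (features, classes) in the samples × features × classes layout. The quick check below runs on random stand-in data, not real SHAP output.

```python
import numpy as np

rng = np.random.default_rng(1)
sv = rng.normal(size=(10, 250, 50))            # class-major layout, as in the answer
values = np.moveaxis(sv, 0, -1)                # (250, 50, 10), samples x features x classes

aggs = np.abs(sv).mean(1)                      # (10, 50)
question_result = np.abs(values).mean(axis=0)  # (50, 10)

print(aggs.shape, question_result.shape)       # (10, 50) (50, 10)
print(np.allclose(aggs.T, question_result))    # True
```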
You can plot it:
sv_df = pd.DataFrame(aggs.T)
sv_df.plot(kind="barh", stacked=True)
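If it still doesn't look like the summary plot, you can sort and filter it (here, keeping the 10 features with the largest total):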
sv_df.loc[sv_df.sum(1).sort_values(ascending=True).index[-10:]].plot(
kind="barh", stacked=True
)
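To inspect the numbers instead of the chart, the same reordering can be done on the DataFrame itself. This sketch uses random stand-in data. The `feature_i` / `class_k` labels are made up for illustration, since `make_classification` returns an unlabeled array. Sorting by the row sum mirrors the total bar length in the stacked plot.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
aggs = np.abs(rng.normal(size=(10, 250, 50))).mean(1)  # stand-in for the (10, 50) aggs above

# Rows = features, columns = classes, with illustrative labels
sv_df = pd.DataFrame(
    aggs.T,
    index=[f"feature_{i}" for i in range(aggs.shape[1])],
    columns=[f"class_{k}" for k in range(aggs.shape[0])],
)

# Total mean |SHAP| per feature (length of the stacked bar), largest first
sv_df["total"] = sv_df.sum(axis=1)
top10 = sv_df.sort_values("total", ascending=False).head(10)
print(top10.round(3))
```

With real data, replace the stand-in `aggs` with the one computed from `sv`. If the original feature names are available, pass them as `index`.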
Conclusion: sv_df holds the aggregated SHAP values shown in the summary plot (the mean absolute SHAP value per feature and class), with one feature per row and one class per column.
Does this help?
Regarding python-3.x - How to extract actual values from a shap summary plot, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/73699610/