python - 使用 SHAP 解释多标签目标 Tensorflow 模型

标签 python tensorflow shap

我有一个带有 2 个目标变量的 tensorflow 模型,我想按如下方式查看其 SHAP 值:

import pandas as pd
import tensorflow as tf
import shap

x_df = pd.DataFrame({'var1':[4, 6, 19, 8],
                         'var2':[7, 21, 5, 12],
                         'var3':[11, 15, 19, 5],
                         'var4':[8, 1, 16, 18]})

target_var = pd.DataFrame({'y1': [12, 4, 6, 8],
                           'y2': [11, 13, 9, 12]})


model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(4, activation='sigmoid'),
    tf.keras.layers.Dense(2)])
model.compile(optimizer='adam', loss="mse")
model.fit(x_df.values, target_var, epochs=1, batch_size=2)

explainer = shap.DeepExplainer(model, x_df)
shap_values = explainer.shap_values(x_df)
shap.summary_plot(shap_values, x_df)

我使用了几个教程来解释我可以将 tf 模型直接插入解释器并使用它,但是 .shap 值返回如下错误:

AttributeError: 'tuple' object has no attribute 'rank'

最佳答案

截至目前(2021 年 7 月),您无法解释多标签。输出必须是一维向量(秩为 1)。

documentation在不同的地方说,例如:

class Deep(Explainer):  

     def __init__(self, model, data, session=None, learning_phase_flags=None):   

    """ An explainer object for a differentiable model using a given background dataset....

     model : if framework == 'tensorflow', (input : [tf.Tensor], output : tf.Tensor)
             A pair of TensorFlow tensors (or a list and a tensor) that specifies the input and
            output of the model to be explained. Note that SHAP values are specific to a single
            output value, so the output tf.Tensor should be a single dimensional output (,1).

Deep 然后在包命名空间中作为 DeepExplainer 导入。

关于python - 使用 SHAP 解释多标签目标 Tensorflow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68373118/

相关文章:

python - 什么时候应该使用 Tensorflow 变量,什么时候应该使用 numpy 或 python 变量

r - 随机森林模型的形状图

python - 如何以概率输出 Shap 值并从二元分类器制作 force_plot

python套接字协议(protocol)不支持

tensorflow - 多维 lstm tensorflow

tensorflow - 在 keras 中,如何使用自定义对象克隆模型?

python-3.x - 如何将绘图(由 shap_values 生成)保存到 png?

python - 如何在 Debug模式下调用PySpark?

python - 在 pyplot.table 中插入图像/对象

java - Python多线程执行多个jar文件需要更长的时间