python - 扩展 Cython 类时,__cinit__() 恰好需要 2 个位置参数

标签 python scikit-learn runtime-error subclass cython

我想扩展 scikit-learn 的 ClassificationCriterion 类,该类在内部模块中定义为 Cython 类 sklearn.tree._criterion 。我想在 Python 中执行此操作,因为通常我无法访问 sklearn 的 pyx/pxd 文件(因此我无法 cimport 它们)。但是,当我尝试扩展 ClassificationCriterion 时,收到错误 TypeError: __cinit__() 恰好需要 2 个位置参数(给定 0 个)。下面的 MWE 重现了该错误,并显示该错误发生在 __new__ 之后但 __init__ 之前。

有什么方法可以像这样扩展 Cython 类吗?

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._criterion import ClassificationCriterion

class MaxChildPrecision(ClassificationCriterion):
    def __new__(self, *args, **kwargs):
        print('new')
        super().__new__(MaxChildPrecision, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        print('init')
        super(MaxChildPrecision).__init__(*args, **kwargs)

clf = DecisionTreeClassifier(criterion=MaxChildPrecision())

最佳答案

有两个问题。首先ClassificationCriterion requires two specific arguments to its constructor that you aren't passing it 。您必须弄清楚这些参数代表什么并将它们传递给基类。

其次,还有 Cython 问题。如果我们看 the description of how to use __cinit__然后我们看到:

Any arguments passed to the constructor will be passed to both the __cinit__() method and the __init__() method. If you anticipate subclassing your extension type in Python, you may find it useful to give the __cinit__() method * and ** arguments so that it can accept and ignore extra arguments. Otherwise, any Python subclass which has an init() with a different signature will have to override __new__() as well as __init__()

不幸的是,sklearn 的作者没有提供 *** 参数,所以你必须重写 __new__。像这样的东西应该有效:

class MaxChildPrecision(ClassificationCriterion):
    def __init__(self,*args, **kwargs):
        pass

    def __new__(cls,*args,**kwargs):
        # I have NO IDEA if these arguments make sense!
        return super().__new__(cls,n_outputs=5,
                           n_classes=np.ones((2,),dtype=np.int))

我将必要的参数传递给__new__中的ClassificationCriterion,并在__init__中处理我认为合适的其余部分。我不需要调用基类 __init__ (因为基类没有定义 __init__)。

关于python - 扩展 Cython 类时,__cinit__() 恰好需要 2 个位置参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47624000/

相关文章:

Python 将回车解码为换行

python - 在sklearn中学习SVM后如何使用dual_coef_param?

tensorflow - 如何使用 tensorflow 数据集训练 sklearn 模型?

python - sklearn 用户的 R 插入符号

Java 运行时错误

python - getopt() 不强制要求参数?

python - Matplotlib - 显示唯一值频率的条形图

python - 为什么 matplotlib 中的图例不能正确显示颜色?

运行暗网检测的 OpenCV 未知层类型

java - 无法修复 android.content.res.Resources$NotFoundException