scala - 如何在 Spark MLlib 中设置自定义损失函数

标签 scala apache-spark machine-learning regression apache-spark-mllib

我想使用我自己的损失函数而不是 linear regression model 的平方损失在 Spark MLlib 中。到目前为止,在文档中找不到任何提到它是否可能的部分。

最佳答案

TLDR; 使用自定义损失函数并不容易,因为您不能简单地将损失函数传递给 Spark 模型。但是,您可以轻松地为自己编写自定义模型。

长答案:
如果你看LinearRegressionWithSGD的代码你会看到:

class LinearRegressionWithSGD private[mllib] (
    private var stepSize: Double,
    private var numIterations: Int,
    private var regParam: Double,
    private var miniBatchFraction: Double)
  extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {

  private val gradient = new LeastSquaresGradient() #Loss Function
  private val updater = new SimpleUpdater()
  @Since("0.8.0")
  override val optimizer = new GradientDescent(gradient, updater) #Optimizer
    .setStepSize(stepSize)
    .setNumIterations(numIterations)
    .setRegParam(regParam)
    .setMiniBatchFraction(miniBatchFraction)

那么,我们来看看最小二乘损失函数是如何实现的here :

class LeastSquaresGradient extends Gradient {
  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
    val diff = dot(data, weights) - label
    val loss = diff * diff / 2.0
    val gradient = data.copy
    scal(diff, gradient)
    (gradient, loss)
  }

  override def compute(
      data: Vector,
      label: Double,
      weights: Vector,
      cumGradient: Vector): Double = {
    val diff = dot(data, weights) - label
    axpy(diff, data, cumGradient)
    diff * diff / 2.0
  }
}

因此,您可以简单地编写一个类似 LeastSquaresGradient 的类并实现 compute 函数并在您的 LinearRegressionWithSGD 模型中使用它。

关于scala - 如何在 Spark MLlib 中设置自定义损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47291964/

相关文章:

scala - 在 IntelliJ Scala 工作表中使用 Apache Spark

scala - 如何在 sbt 下使用 Quasar 和 Scala?

http - Netty 中 HttpMessageDecoder.skipControlCharacters 上的 NullPointerException

java - 我可以确定正则表达式模式匹配的第一个字符集吗?

java - Spark saveAsTextFile() 导致 Mkdirs 无法为目录的一半创建

r - H2O 中的集成(随机森林)-多项分布

python - 在 Keras 中加载模型权重时出现问题

scala - 折叠 Action 在 Spark 中是如何工作的?

java - Spark 数据集 - 读取 CSV 并写入空输出

python - 如何使用带有 Tensorflow 的 Docker 在 Mac 终端上运行 Python 脚本?