java - How do I use the Java Math Commons CurveFitter?

Tags: java math curve-fitting

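How do I use the Math Commons CurveFitter to fit a function to a set of data? I was told to use CurveFitter with LevenbergMarquardtOptimizer and ParametricUnivariateFunction, but I don't know what to write in the ParametricUnivariateFunction gradient and value methods. Also, once they are written, how do I get the fitted function parameters? My function: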

public static double fnc(double t, double a, double b, double c){
  return a * Math.pow(t, b) * Math.exp(-c * t);
}

Best answer

So, this is an old question, but I ran into the same issue recently, and ended up having to delve into mailing lists and the Apache Commons Math source code to figure it out.

This API is remarkably poorly documented, but in current versions of Apache Commons Math (3.3+) there are two parts, assuming you have a single variable with multiple parameters: the function to fit (which implements ParametricUnivariateFunction) and the curve fitter (which extends AbstractCurveFitter).

Function to Fit

  • public double value(double t, double... parameters)
    • Your equation. This is where you would put your fnc logic.
  • public double[] gradient(double t, double... parameters)
    • Returns an array of the partial derivatives of the above with respect to each parameter, in the same order as the parameters. An online derivative calculator may be helpful if (like me) you're rusty on your calculus, but any good computer algebra system can calculate these values.
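
If you want to double-check hand-derived partial derivatives before handing them to the fitter, one option is to compare them against central finite differences. The following is only a sketch in plain Java with no Commons Math dependency: the class name GradientCheck, the helper names, and the step size 1e-6 are assumptions made for illustration, and the value and gradient bodies simply mirror the function from the question.

```java
public class GradientCheck {
    // Same model as in the question: f(t) = a * t^b * exp(-c * t).
    static double value(double t, double... p) {
        return p[0] * Math.pow(t, p[1]) * Math.exp(-p[2] * t);
    }

    // Analytic partial derivatives with respect to a, b and c (valid for t > 0).
    static double[] gradient(double t, double... p) {
        final double a = p[0], b = p[1], c = p[2];
        final double tb = Math.pow(t, b);
        final double e = Math.exp(-c * t);
        return new double[] {
            tb * e,                    // df/da
            a * tb * Math.log(t) * e,  // df/db
            -a * t * tb * e            // df/dc
        };
    }

    // Central-difference estimate of df/dp[i]; h is an assumed step size.
    static double numericPartial(double t, double[] p, int i, double h) {
        double[] hi = p.clone();
        double[] lo = p.clone();
        hi[i] += h;
        lo[i] -= h;
        return (value(t, hi) - value(t, lo)) / (2 * h);
    }

    // Largest absolute gap between analytic and numeric partials at (t, p).
    static double maxError(double t, double[] p) {
        double[] g = gradient(t, p);
        double worst = 0.0;
        for (int i = 0; i < p.length; i++) {
            worst = Math.max(worst, Math.abs(g[i] - numericPartial(t, p, i, 1e-6)));
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] p = { 2.0, 1.5, 0.3 };  // arbitrary test parameters
        for (double t : new double[] { 0.5, 1.0, 4.0 }) {
            System.out.println("t=" + t + " max error=" + maxError(t, p));
        }
    }
}
```

A gap many orders of magnitude smaller than the derivative values themselves suggests the analytic gradient is consistent with the value method. Note that the df/db term uses Math.log(t), so it is only well defined for t > 0; at t = 0 it evaluates to NaN.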

Curve Fitter

  • protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points)
    • Sets up a bunch of boilerplate crap, and returns a least squares problem for the fitter to use.

Putting it all together, here's an example solution in your specific case:

import java.util.*;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.AbstractCurveFitter;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.apache.commons.math3.linear.DiagonalMatrix;

class MyFunc implements ParametricUnivariateFunction {
    public double value(double t, double... parameters) {
        return parameters[0] * Math.pow(t, parameters[1]) * Math.exp(-parameters[2] * t);
    }

    // Jacobian matrix of the above. In this case, this is just an array of
    // partial derivatives of the above function, with one element for each parameter.
    public double[] gradient(double t, double... parameters) {
        final double a = parameters[0];
        final double b = parameters[1];
        final double c = parameters[2];

        return new double[] {
            Math.exp(-c*t) * Math.pow(t, b),
            a * Math.exp(-c*t) * Math.pow(t, b) * Math.log(t),
            a * (-Math.exp(-c*t)) * Math.pow(t, b+1)
        };
    }
}

public class MyFuncFitter extends AbstractCurveFitter {
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points) {
        final int len = points.size();
        final double[] target  = new double[len];
        final double[] weights = new double[len];
        final double[] initialGuess = { 1.0, 1.0, 1.0 };

        int i = 0;
        for(WeightedObservedPoint point : points) {
            target[i]  = point.getY();
            weights[i] = point.getWeight();
            i += 1;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model = new
            AbstractCurveFitter.TheoreticalValuesFunction(new MyFunc(), points);

        return new LeastSquaresBuilder().
            maxEvaluations(Integer.MAX_VALUE).
            maxIterations(Integer.MAX_VALUE).
            start(initialGuess).
            target(target).
            weight(new DiagonalMatrix(weights)).
            model(model.getModelFunction(), model.getModelFunctionJacobian()).
            build();
    }

    public static void main(String[] args) {
        MyFuncFitter fitter = new MyFuncFitter();
        ArrayList<WeightedObservedPoint> points = new ArrayList<WeightedObservedPoint>();

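        // Add points here. The constructor arguments are (weight, x, y).
        // One point is only a placeholder: a meaningful fit needs at least
        // as many points as parameters (three here). For instance,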
        WeightedObservedPoint point = new WeightedObservedPoint(1.0,
            1.0,
            1.0);
        points.add(point);

        final double coeffs[] = fitter.fit(points);
        System.out.println(Arrays.toString(coeffs));
    }
}
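
The main method above adds only a single placeholder point. To exercise the fitter, it can help to feed it samples generated from known parameters and see whether the returned coefficients come out close to them. Below is a sketch of such a generator in plain Java; the class name SyntheticSamples, the sample helper, and the chosen parameter values are hypothetical, and whether the fit actually converges from the initial guess { 1.0, 1.0, 1.0 } is not guaranteed.

```java
import java.util.ArrayList;
import java.util.List;

public class SyntheticSamples {
    // Model from the question: f(t) = a * t^b * exp(-c * t).
    static double fnc(double t, double a, double b, double c) {
        return a * Math.pow(t, b) * Math.exp(-c * t);
    }

    // Returns n rows of {t, y} with t = step, 2*step, ..., n*step,
    // so t stays strictly positive (the gradient uses log(t)).
    static List<double[]> sample(double a, double b, double c, int n, double step) {
        List<double[]> rows = new ArrayList<>();
        for (int i = 1; i <= n; i++) {
            double t = i * step;
            rows.add(new double[] { t, fnc(t, a, b, c) });
        }
        return rows;
    }

    public static void main(String[] args) {
        // Known parameters a=2.0, b=1.5, c=0.3 (arbitrary values for illustration).
        for (double[] row : sample(2.0, 1.5, 0.3, 20, 0.5)) {
            // In the fitter's main method each row would be added as:
            // points.add(new WeightedObservedPoint(1.0, row[0], row[1]));
            System.out.println(row[0] + "\t" + row[1]);
        }
    }
}
```

Since the data here is noise-free, a successful fit should return coefficients near the ones used to generate it, which answers the "how do I get the fitted parameters" part: they are the array returned by fitter.fit(points), in the same order as the parameters of value.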

Regarding "java - How do I use the Java Math Commons CurveFitter?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/11335127/
