scipy - 我可以使用什么算法来识别此散点图中的线?

标签 scipy linear-regression

我正在创建一个程序来比较音频文件,该程序使用与此处描述的算法类似的算法 http://www.ee.columbia.edu/~dpwe/papers/Wang03-shazam.pdf .我正在绘制被比较的两首歌曲之间的匹配时间并找到该图的最小二乘线。 href= http://imgur.com/fGu7jhX&yOeMSK0是匹配文件的示例图。图太乱,最小二乘回归线虽然有明显的线,但相关系数不高。我可以使用什么其他算法来识别这条线?

最佳答案

这是一个有趣的问题,但它一直很安静。也许这个答案
会触发更多的事件。

用于识别集合中具有任意斜率和截距的线
的点,霍夫变换将是一个很好的起点。为您的音频
应用程序,但是,看起来斜率应该始终为 1,因此您不必
需要霍夫变换的全部一般性。

相反,您可以将问题视为聚类差异之一 x - y ,其中 xy是保存点的 x 和 y 坐标的向量。

一种方法是计算 x - y 的直方图。 .接近位于斜率为 1 的同一条线上的点将在直方图中的同一条柱中存在差异。计数最大的 bin 对应于近似对齐的最大点集合。这种方法要处理的一个问题是选择直方图箱的边界。一个错误的选择可能会导致应该组合在一起的点被分成相邻的 bin。

一个简单的蛮力方法是想象一个具有给定宽度的对角线窗口,在 (x,y) 平面上从左向右滑动。一条线的最佳候选对应于包含最多点的窗口的位置。这类似于 x - y 的直方图,但不是有一组不相交的垃圾箱,而是有重叠的垃圾箱,每个点一个。所有的 bin 具有相同的宽度,每个点决定了 bin 的左边缘。

函数count_diag_groups在下面的代码中进行该计算。对于每个点,当窗口的左边缘在该点上时,它计算对角窗口中有多少点。一条线的最佳候选者是点数最多的窗口。这是脚本生成的图。顶部是数据的散点图。底部是相同的散点图,突出显示了最佳候选点。

Plot generated by the script

这种方法的一个很好的特点是只有一个参数,即窗口宽度。一个不太好的特性是它的时间复杂度为 O(n**2),其中 n 是点数。肯定有时间复杂度更好的算法可以做类似的事情;您链接到的文章对此进行了讨论。然而,要判断替代方案的质量,将需要更具体的规范,说明线路识别必须有多“好”或多稳健。

import numpy as np
import matplotlib.pyplot as plt


def count_diag_groups(x, y, width):
    """
    Returns a list of arrays.  The length of the list is the same
    as the length of x.  The k-th array holds the indices into x
    (and y) of a set of points that are in a "diagonal" window with
    the given width whose left edge includes the point (x[k], y[k]).
    """
    d = x - y
    result = []
    for i in range(d.size):
        delta = d - d[i]
        neighbors = np.where((delta >= 0) & (delta <= width))[0]
        result.append(neighbors)
    return result


def generate_demo_data():
    # Generate some data.
    np.random.seed(123)
    xmin = 0
    xmax = 100
    ymin = 0
    ymax = 25
    nrnd = 175
    xrnd = xmin + (xmax - xmin)*np.random.rand(nrnd)
    yrnd = ymin + (ymax - ymin)*np.random.rand(nrnd)
    n = 25
    xx = xmin + 0.1*(xmax - xmin) + ymax*np.random.rand(n)
    yy = (xx - xx.min()) + 0.2*np.random.randn(n)
    x = np.concatenate((xrnd, xx))
    y = np.concatenate((yrnd, yy))
    return x, y


def plot_result(x, y, width, selection):
    xmin = x.min()
    xmax = x.max()
    ymin = y.min()
    ymax = y.max()

    xsel = x[selection]
    ysel = y[selection]
    # Plot...
    plt.figure(1)
    plt.clf()
    ax = plt.subplot(2,1,1)
    plt.plot(x, y, 'o', mfc='b', mec='b', alpha=0.5)
    plt.xlim(xmin - 1, xmax + 1)
    plt.ylim(ymin - 1, ymax + 1)

    plt.subplot(2,1,2, sharex=ax, sharey=ax)
    plt.plot(x, y, 'o', mfc='b', mec='b', alpha=0.5)
    plt.plot(xsel, ysel, 'o', mfc='w', mec='w')
    plt.plot(xsel, ysel, 'o', mfc='r', mec='r', alpha=0.65)
    xi = np.array([xmin, xmax])
    d = x - y
    yi1 = xi - d[imax]
    yi2 = yi1 - width
    plt.plot(xi, yi1, 'r-', alpha=0.25)
    plt.plot(xi, yi2, 'r-', alpha=0.25)
    plt.xlim(xmin - 1, xmax + 1)
    plt.ylim(ymin - 1, ymax + 1)

    plt.show()

if __name__ == "__main__":
    x, y = generate_demo_data()

    # Find a selection of points that are close to being aligned
    # with a slope of 1.
    width = 0.75
    r = count_diag_groups(x, y, width)

    # Find the largest group.
    sz = np.array(list(len(f) for f in r))
    imax = sz.argmax()
    # k holds the indices of the selected points.
    selection = r[imax]

    plot_result(x, y, width, selection)

关于scipy - 我可以使用什么算法来识别此散点图中的线?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20296030/

相关文章:

python-3.x - Python不准确的曲线拟合

machine-learning - 超过 2 theta 值的梯度下降

c# - 使用 C# 的线性回归梯度下降

apache-spark - Spark Linear Regression With SGD 对特征缩放非常敏感

python - Scikit-learn的fetch不下载数据集

python - python/opencv:如何通过给定图像上的点识别循环?

scipy - 将一组 Pandas Series reshape 为 DataFrame 并填充缺失值

python - 使用 Sklearn 进行梯度提升

python - 在 Keras 中训练多元回归模型时损失值非常大

python - 如何从 sklearn LinearRegression 导出线性回归公式