python - 在Python中高精度地找到由(x,y)数据给出的两条曲线的交点

标签 python algorithm numpy scipy

我有两个数据集:(x, y1) 和 (x, y2)。我想找到这两条曲线相交的位置。目标类似于这个问题:Intersection of two graphs in Python, find the x value:

但是,那里描述的方法只能找到最近数据点的交点。我想找到比原始数据间距精度更高的曲线交点。一种选择是简单地重新插值到更精细的网格。这可行,但精度取决于我为重新插值选择的点数,这是任意的,需要在精度和效率之间进行权衡。

或者,我可以使用 scipy.optimize.fsolve 找到数据集的两个样条插值的精确交集。这很好用,但它不能轻易找到多个交点,需要我为交点提供合理的猜测,并且可能无法很好地扩展。 (最终,我想找到几千组(x,y1,y2)的交集,所以一个高效的算法会很好。)

这是我目前所拥有的。有什么改进的想法吗?

import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate, scipy.optimize

x  = np.linspace(1, 4, 20)
y1 = np.sin(x)
y2 = 0.05*x

plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

interp1 = scipy.interpolate.InterpolatedUnivariateSpline(x, y1)
interp2 = scipy.interpolate.InterpolatedUnivariateSpline(x, y2)

new_x = np.linspace(x.min(), x.max(), 100)
new_y1 = interp1(new_x)
new_y2 = interp2(new_x)
idx = np.argwhere(np.diff(np.sign(new_y1 - new_y2)) != 0)
plt.plot(new_x[idx], new_y1[idx], 'ro', ms=7, label='Nearest data-point method, with re-interpolated data')

def difference(x):
    return np.abs(interp1(x) - interp2(x))

x_at_crossing = scipy.optimize.fsolve(difference, x0=3.0)
plt.plot(x_at_crossing, interp1(x_at_crossing), 'cd', ms=7, label='fsolve method')

plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

plt.savefig('curve crossing.png', dpi=200)
plt.show()

enter image description here

最佳答案

最佳(也是最有效)的答案可能取决于数据集及其采样方式。但是,许多数据集的一个很好的近似是它们在数据点之间几乎是线性的。因此,我们可以通过原帖中显示的“最近数据点”方法找到交叉点的大致位置。然后,我们可以使用线性插值细化最近的两个数据点之间的交点位置。

这种方法非常快,并且适用于 2D numpy 数组,以防您想一次计算多条曲线的交叉点(就像我想在我的应用程序中做的那样)。

(我从“How do I compute the intersection point of two lines in Python?”中借用代码进行线性插值。)

from __future__ import division 
import numpy as np
import matplotlib.pyplot as plt

def interpolated_intercept(x, y1, y2):
    """Find the intercept of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the first line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y

        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
    xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
    return xc,yc

def main():
    x  = np.linspace(1, 4, 20)
    y1 = np.sin(x)
    y2 = 0.05*x

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

    # new method!
    xc, yc = interpolated_intercept(x,y1,y2)
    plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

    plt.savefig('curve crossing.png', dpi=200)
    plt.show()

if __name__ == '__main__': 
    main()

Curve crossing

2018-12-13 更新: 如果需要找到多个拦截,这里是执行此操作的代码的修改版本:

from __future__ import division 
import numpy as np
import matplotlib.pyplot as plt

def interpolated_intercepts(x, y1, y2):
    """Find the intercepts of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the first line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y

        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idxs = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    xcs = []
    ycs = []

    for idx in idxs:
        xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
        xcs.append(xc)
        ycs.append(yc)
    return np.array(xcs), np.array(ycs)

def main():
    x  = np.linspace(1, 10, 50)
    y1 = np.sin(x)
    y2 = 0.02*x

    plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
    plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

    plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

    # new method!
    xcs, ycs = interpolated_intercepts(x,y1,y2)
    for xc, yc in zip(xcs, ycs):
        plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')


    plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

    plt.savefig('curve crossing.png', dpi=200)
    plt.show()

if __name__ == '__main__': 
    main()

``` enter image description here

关于python - 在Python中高精度地找到由(x,y)数据给出的两条曲线的交点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42464334/

相关文章:

python - 将 Python 可执行文件部署到 Azure Service Fabric

python 正则表达式不使用 re.match 和 re.MULTILINE 标志匹配文件内容

python - 带有标题的 Pandas Dataframe 的嵌套列表

c# - 简单拼图算法

algorithm - 如何压缩排序单词列表?

python - 我是否正确地将傅里叶变换应用于此数据框? [EXAFS X 射线吸收数据框]

python - Scrapy和 Selenium : How to call a method in a loop

algorithm - 我如何在进化算法中表示百分比?

python - numpy pandas 的构建轮失败

python - 使用 Numpy 的多项式的系数误差