我有两个数据集:(x, y1) 和 (x, y2)。我想找到这两条曲线相交的位置。目标类似于这个问题:Intersection of two graphs in Python, find the x value:
但是,那里描述的方法只能找到最近数据点的交点。我想找到比原始数据间距精度更高的曲线交点。一种选择是简单地重新插值到更精细的网格。这可行,但精度取决于我为重新插值选择的点数,这是任意的,需要在精度和效率之间进行权衡。
或者,我可以使用 scipy.optimize.fsolve
找到数据集的两个样条插值的精确交集。这很好用,但它不能轻易找到多个交点,需要我为交点提供合理的猜测,并且可能无法很好地扩展。 (最终,我想找到几千组(x,y1,y2)的交集,所以一个高效的算法会很好。)
这是我目前所拥有的。有什么改进的想法吗?
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate, scipy.optimize
x = np.linspace(1, 4, 20)
y1 = np.sin(x)
y2 = 0.05*x
plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')
idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')
interp1 = scipy.interpolate.InterpolatedUnivariateSpline(x, y1)
interp2 = scipy.interpolate.InterpolatedUnivariateSpline(x, y2)
new_x = np.linspace(x.min(), x.max(), 100)
new_y1 = interp1(new_x)
new_y2 = interp2(new_x)
idx = np.argwhere(np.diff(np.sign(new_y1 - new_y2)) != 0)
plt.plot(new_x[idx], new_y1[idx], 'ro', ms=7, label='Nearest data-point method, with re-interpolated data')
def difference(x):
return np.abs(interp1(x) - interp2(x))
x_at_crossing = scipy.optimize.fsolve(difference, x0=3.0)
plt.plot(x_at_crossing, interp1(x_at_crossing), 'cd', ms=7, label='fsolve method')
plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')
plt.savefig('curve crossing.png', dpi=200)
plt.show()
最佳答案
最佳(也是最有效)的答案可能取决于数据集及其采样方式。但是,许多数据集的一个很好的近似是它们在数据点之间几乎是线性的。因此,我们可以通过原帖中显示的“最近数据点”方法找到交叉点的大致位置。然后,我们可以使用线性插值细化最近的两个数据点之间的交点位置。
这种方法非常快,并且适用于 2D numpy 数组,以防您想一次计算多条曲线的交叉点(就像我想在我的应用程序中做的那样)。
(我从“How do I compute the intersection point of two lines in Python?”中借用代码进行线性插值。)
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
def interpolated_intercept(x, y1, y2):
"""Find the intercept of two curves, given by the same x data"""
def intercept(point1, point2, point3, point4):
"""find the intersection between two lines
the first line is defined by the line between point1 and point2
the first line is defined by the line between point3 and point4
each point is an (x,y) tuple.
So, for example, you can find the intersection between
intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)
Returns: the intercept, in (x,y) format
"""
def line(p1, p2):
A = (p1[1] - p2[1])
B = (p2[0] - p1[0])
C = (p1[0]*p2[1] - p2[0]*p1[1])
return A, B, -C
def intersection(L1, L2):
D = L1[0] * L2[1] - L1[1] * L2[0]
Dx = L1[2] * L2[1] - L1[1] * L2[2]
Dy = L1[0] * L2[2] - L1[2] * L2[0]
x = Dx / D
y = Dy / D
return x,y
L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
L2 = line([point3[0],point3[1]], [point4[0],point4[1]])
R = intersection(L1, L2)
return R
idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
return xc,yc
def main():
x = np.linspace(1, 4, 20)
y1 = np.sin(x)
y2 = 0.05*x
plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')
idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')
# new method!
xc, yc = interpolated_intercept(x,y1,y2)
plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')
plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')
plt.savefig('curve crossing.png', dpi=200)
plt.show()
if __name__ == '__main__':
main()
2018-12-13 更新: 如果需要找到多个拦截,这里是执行此操作的代码的修改版本:
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
def interpolated_intercepts(x, y1, y2):
"""Find the intercepts of two curves, given by the same x data"""
def intercept(point1, point2, point3, point4):
"""find the intersection between two lines
the first line is defined by the line between point1 and point2
the first line is defined by the line between point3 and point4
each point is an (x,y) tuple.
So, for example, you can find the intersection between
intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)
Returns: the intercept, in (x,y) format
"""
def line(p1, p2):
A = (p1[1] - p2[1])
B = (p2[0] - p1[0])
C = (p1[0]*p2[1] - p2[0]*p1[1])
return A, B, -C
def intersection(L1, L2):
D = L1[0] * L2[1] - L1[1] * L2[0]
Dx = L1[2] * L2[1] - L1[1] * L2[2]
Dy = L1[0] * L2[2] - L1[2] * L2[0]
x = Dx / D
y = Dy / D
return x,y
L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
L2 = line([point3[0],point3[1]], [point4[0],point4[1]])
R = intersection(L1, L2)
return R
idxs = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
xcs = []
ycs = []
for idx in idxs:
xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
xcs.append(xc)
ycs.append(yc)
return np.array(xcs), np.array(ycs)
def main():
x = np.linspace(1, 10, 50)
y1 = np.sin(x)
y2 = 0.02*x
plt.plot(x, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
plt.plot(x, y2, marker='o', mec='none', ms=4, lw=1, label='y2')
idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
plt.plot(x[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')
# new method!
xcs, ycs = interpolated_intercepts(x,y1,y2)
for xc, yc in zip(xcs, ycs):
plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')
plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')
plt.savefig('curve crossing.png', dpi=200)
plt.show()
if __name__ == '__main__':
main()
关于python - 在Python中高精度地找到由(x,y)数据给出的两条曲线的交点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42464334/