python - 查找 xy 数据点图与 numpy 的所有交集?

标签 python numpy intersection points

我正在分析循环拉伸(stretch)试验的数据。作为输入,使用了巨大的 x 和 y 值列表。 为了描述 Material 是硬化还是软化,我需要获取每个循环循环的蓝色斜率。

tensile_test

slope

爬到坡的低点是 child 节,爬到坡的高点就是挑战。

the_challenge

到目前为止,我已经采用了这种方法,用低于每个循环的局部最大值的几个点切出循环,并根据 硬编号 点数制作红线。红线的近似值由 poly1d(polyfit(x1,x2,1)) 生成,然后使用 fsolve 获得交点。然而,它并不总是正确工作,因为点的分布并不总是相同的。

问题是如何正确定义两条(红色)相交线的间隔。上图中是 3 个实验以及平均斜率。我花了几天时间试图为每个循环找到 4 个最近的点,确定这不是最佳方法。最后,我在 stackowerflow 上结束了。

期望的输出是带有交点近似坐标的列表——如果你想玩,here是曲线的数据 (0,[[xvals],[yvals]])。 Theese 可以很容易地阅读

import csv
import sys
csv. field_size_limit(sys.maxsize)     

csvfile = 'data.csv'
tc_data = {}
for key, val in csv.reader(open(csvfile, "r")):
    tc_data[key] = val
for key in tc_data:
  tc = eval(tc_data[key])

x = tc[0]
y = tc[1]

最佳答案

这可能有点矫枉过正,但找到交点的正确方法是,一旦您将曲线分割成 block ,就是查看第一个 block 中的任何线段是否与第二个 block 中的任何线段相交。

我要为自己制作一些简单的数据,prolate cycloid 的一部分,我将找到 y 坐标从增加到减少的翻转位置,类似于 here :

a, b = 1, 2
phi = np.linspace(3, 10, 100)
x = a*phi - b*np.sin(phi)
y = a - b*np.cos(phi)
y_growth_flips = np.where(np.diff(np.diff(y) > 0))[0] + 1

plt.plot(x, y, 'rx')
plt.plot(x[y_growth_flips], y[y_growth_flips], 'bo')
plt.axis([2, 12, -1.5, 3.5])
plt.show()

enter image description here

如果您有两个路段,一个从点 P0P1,另一个从点 Q0Q1,你可以通过求解向量方程 P0 + s*(P1-P0) = Q0 + t*(Q1-Q0) 找到它们的交点,如果两个线段确实相交st 都在 [0, 1] 中。对所有分割市场进行尝试:

x_down = x[y_growth_flips[0]:y_growth_flips[1]+1]
y_down = y[y_growth_flips[0]:y_growth_flips[1]+1]
x_up = x[y_growth_flips[1]:y_growth_flips[2]+1]
y_up = y[y_growth_flips[1]:y_growth_flips[2]+1]

def find_intersect(x_down, y_down, x_up, y_up):
    for j in xrange(len(x_down)-1):
        p0 = np.array([x_down[j], y_down[j]])
        p1 = np.array([x_down[j+1], y_down[j+1]])
        for k in xrange(len(x_up)-1):
            q0 = np.array([x_up[k], y_up[k]])
            q1 = np.array([x_up[k+1], y_up[k+1]])
            params = np.linalg.solve(np.column_stack((p1-p0, q0-q1)),
                                     q0-p0)
            if np.all((params >= 0) & (params <= 1)):
                return p0 + params[0]*(p1 - p0)

>>> find_intersect(x_down, y_down, x_up, y_up)
array([ 6.28302264,  1.63658676])

crossing_point = find_intersect(x_down, y_down, x_up, y_up)
plt.plot(crossing_point[0], crossing_point[1], 'ro')
plt.show()

enter image description here

在我的系统上,这可以每秒处理大约 20 个交叉点,这不是超快,但可能足以不时分析图表。您可以通过向量化 2x2 线性系统的解决方案来加快速度:

def find_intersect_vec(x_down, y_down, x_up, y_up):
    p = np.column_stack((x_down, y_down))
    q = np.column_stack((x_up, y_up))
    p0, p1, q0, q1 = p[:-1], p[1:], q[:-1], q[1:]
    rhs = q0 - p0[:, np.newaxis, :]
    mat = np.empty((len(p0), len(q0), 2, 2))
    mat[..., 0] = (p1 - p0)[:, np.newaxis]
    mat[..., 1] = q0 - q1
    mat_inv = -mat.copy()
    mat_inv[..., 0, 0] = mat[..., 1, 1]
    mat_inv[..., 1, 1] = mat[..., 0, 0]
    det = mat[..., 0, 0] * mat[..., 1, 1] - mat[..., 0, 1] * mat[..., 1, 0]
    mat_inv /= det[..., np.newaxis, np.newaxis]
    import numpy.core.umath_tests as ut
    params = ut.matrix_multiply(mat_inv, rhs[..., np.newaxis])
    intersection = np.all((params >= 0) & (params <= 1), axis=(-1, -2))
    p0_s = params[intersection, 0, :] * mat[intersection, :, 0]
    return p0_s + p0[np.where(intersection)[0]]

是的,它很困惑,但它确实有效,而且速度快了 100 倍:

find_intersect(x_down, y_down, x_up, y_up)
Out[67]: array([ 6.28302264,  1.63658676])

find_intersect_vec(x_down, y_down, x_up, y_up)
Out[68]: array([[ 6.28302264,  1.63658676]])

%timeit find_intersect(x_down, y_down, x_up, y_up)
10 loops, best of 3: 66.1 ms per loop

%timeit find_intersect_vec(x_down, y_down, x_up, y_up)
1000 loops, best of 3: 375 us per loop

关于python - 查找 xy 数据点图与 numpy 的所有交集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/17928452/

相关文章:

python - Numpy:循环替代、优化

javascript - 检查一条线和一个网格是否在 three.js 中相交

java - 正则表达式 Java。为什么要使用交集?

2d 游戏 : fire at a moving target by predicting intersection of projectile and unit

python按相同键值计数对字典列表进行排序

python - 生成包含随机 boolean 值的大型 numpy 数组的内存有效方法

python - 如何访问该模型类中的数据?

Python - 如何像行一样读取/解析 csv?

python - 在 Basemap 上绘制文本字符串代替 Python 中的点

python - 格式化包含非 ascii 字符的列