python - Numpy:制作四元数乘法的批处理版本

标签 python numpy linear-algebra

我改造了下面的函数

def quaternion_multiply(quaternion0, quaternion1):
    """Return multiplication of two quaternions.

    >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
    >>> numpy.allclose(q, [-44, -14, 48, 28])
    True

    """
    x0, y0, z0, w0 = quaternion0
    x1, y1, z1, w1 = quaternion1
    return numpy.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)

到批处理版本

def quat_multiply(self, quaternion0, quaternion1):
    x0, y0, z0, w0 = np.split(quaternion0, 4, 1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, 1)

    result = np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)
    return np.transpose(np.squeeze(result))

此函数处理形状为 (?,4) 的 quaternion1 和 quaternion0。现在我希望该函数可以处理任意数量的维度,例如 (?,?,4)。如何做到这一点?

最佳答案

只需将 axis-=-1 传递给 np.split 即可沿最后一个轴拆分,从而获得所需的行为。

并且由于您的数组具有烦人的大小为 1 的尾随维度,而不是沿着新的维度堆叠,然后将那个维度挤出,您可以简单地连接它们,再次沿着(最后一个)axis=-1:

def quat_multiply(self, quaternion0, quaternion1):
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    return np.concatenate(
        (x1*w0 + y1*z0 - z1*y0 + w1*x0,
         -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
         -x1*x0 - y1*y0 - z1*z0 + w1*w0),
        axis=-1)

请注意,使用这种方法,您不仅可以乘以任意维数的相同形状的四元数堆栈:

>>> a = np.random.rand(6, 5, 4)
>>> b = np.random.rand(6, 5, 4)
>>> quat_multiply(None, a, b).shape
(6, 5, 4)

但是你也得到了很好的广播,它允许你将一堆四元数与一个四元数相乘,而不必摆弄维度:

>>> a = np.random.rand(6, 5, 4)
>>> b = np.random.rand(4)
>>> quat_multiply(None, a, b).shape
(6, 5, 4)

或者用最少的操作在一行中完成两个堆栈之间的所有叉积:

>>> a = np.random.rand(6, 4)
>>> b = np.random.rand(5, 4)
>>> quat_multiply(None, a[:, None], b).shape
(6, 5, 4)

关于python - Numpy:制作四元数乘法的批处理版本,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40915069/

相关文章:

python - 将一个以小写字母开头的元素连接到列表的前一个元素

Python:通用导入

python - 如何从元组迭代创建 numpy.ndarray

python - 在 Python 中查找矩阵变为奇异的值

python - 如何从矩阵中找到线性独立的行

python - matplotlib.pyplot 不会忘记以前的图 - 我如何刷新/刷新?

python - 非阻塞套接字,错误总是

python - scrapy 蜘蛛如何将值返回给另一个蜘蛛

python - Numpy 嵌套结构化数组引用

python - 创建随机线性无关向量