python - 如何在通过求解另一个分段函数获得的分段函数中对 ExprCondPair 进行排序?

标签 python sympy solver

让我们考虑以下示例:

import sympy as sym

x, y = sym.symbols(['x', 'y'])

cdf = sym.Piecewise((0, y < 0), 
                    (y, y < 1), 
                    (2*y - 1, y <= 2), 
                    (3, True))
eq = sym.Eq(x, cdf)
inverse = sym.solve(eq, y, rational=False)  # rational prevents buggy exception
print(inverse)

输出:

[Piecewise((x, x < 1), (nan, True)), 
 Piecewise((x/2 + 1/2, x/2 + 1/2 <= 2), (nan, True))]

可以使用以下函数轻松将其转换为单个 Piecewise:

from typing import List

def recreate_piecewise(functions: List[sym.Piecewise]) -> sym.Piecewise:
    """Collects Piecewise from a list of Piecewise functions"""
    return sym.Piecewise(*[piecewise.args[0]
                           for piecewise in functions])


print(recreate_piecewise(inverse))

输出:

Piecewise((x, x < 1), (x/2 + 1/2, x/2 + 1/2 <= 2))

该分段条件的根是13。这些根按从小到大的顺序排列。

当求解任何其他分段函数时,我希望它们的部分解决方案能够以相同的方式排序。但不幸的是,事实并非如此。

<小时/>

这可以通过以下示例来显示:

cdf = sym.Piecewise((0, y < 4.3), 
                    (y - 4.3, y < 12.9), 
                    (5*y - 55.9, y <= 13.5), 
                    (11.6, True))
eq = sym.Eq(x, cdf)
inverse = sym.solve(eq, y, rational=False)
print(recreate_piecewise(inverse))

输出:

Piecewise((x/5 + 11.18, x/5 + 11.18 <= 13.5), 
          (x + 4.3, x + 4.3 < 12.9))

这里的根是 11.68.6,这是不同的顺序。

<小时/>

问题:
如何获得始终按相同顺序排序的分段解决方案?

<小时/>

我尝试过的:
我已经实现了以下辅助函数。该解决方案有效,但不幸的是并不适用于所有情况。另外,我觉得我在这里使用了太多的解决方法,有些事情可以更容易完成。

from sympy.functions.elementary.piecewise import ExprCondPair


def replace_inequalities(expression: sym.Expr) -> sym.Expr:
    """Replaces <, <=, >, >= by == in expression"""
    conditions = [sym.Lt, sym.Le, sym.Gt, sym.Ge]
    for condition in conditions:
        expression = expression.replace(condition, sym.Eq)
    return expression


def piecewise_part_condition_root(expression_condition: ExprCondPair) -> float:
    """Returns a root of inequality part"""
    condition = expression_condition[1]
    equation = replace_inequalities(condition)
    return sym.solve(equation, x)[0]


def to_be_sorted(piecewise: sym.Function) -> bool:
    """Checks if elements of Piecewise have to be sorted"""
    first_line = piecewise.args[0]
    last_line = piecewise.args[-1]

    first_root = piecewise_part_condition_root(first_line)
    last_root = piecewise_part_condition_root(last_line)

    return last_root < first_root


def sort_piecewise(piecewise: sym.Piecewise) -> sym.Piecewise:
    """Inverts the order of elements in Piecewise"""
    return sym.Piecewise(*[part
                           for part in piecewise.args[::-1]])

对于第一个和第二个示例,它都可以工作。
第一个:

cdf = sym.Piecewise((0, y < 0), 
                    (y, y < 1), 
                    (2*y - 1, y <= 2), 
                    (3, True))
eq = sym.Eq(x, cdf)
inverse = sym.solve(eq, y, rational=False)
inverse = recreate_piecewise(inverse)

if to_be_sorted(inverse):
    inverse = sort_piecewise(inverse)
print(inverse)

输出:

Piecewise((x, x < 1), 
          (x/2 + 1/2, x/2 + 1/2 <= 2))

第二个:

cdf = sym.Piecewise((0, y < 4.3), 
                    (y - 4.3, y < 12.9), 
                    (5*y - 55.9, y <= 13.5), 
                    (11.6, True))
eq = sym.Eq(x, cdf)
inverse = sym.solve(eq, y, rational=False)
inverse = recreate_piecewise(inverse)

if to_be_sorted(inverse):
    inverse = sort_piecewise(inverse)
print(inverse)    

输出:

Piecewise((x + 4.3, x + 4.3 < 12.9), 
          (x/5 + 11.18, x/5 + 11.18 <= 13.5))

但是如果我举一个解决方案包含 LambertW 函数的示例,我的方法将失败:

def to_lower_lambertw_branch(*args) -> sym.Function:
    """
    Wraps the first argument from a given list of arguments
    as a lower branch of LambertW function.
    """
    return sym.LambertW(args[0], -1)


def replace_lambertw_branch(expression: sym.Function) -> sym.Function:
    """
    Replaces upper branch of LambertW function with the lower one.
    For details of the bug see:
    https://stackoverflow.com/questions/49817984/sympy-solve-doesnt-give-one-of-the-solutions-with-lambertw
    Solution is based on the 2nd example from:
    http://docs.sympy.org/latest/modules/core.html?highlight=replace#sympy.core.basic.Basic.replace
    """
    return expression.replace(sym.LambertW,
                              to_lower_lambertw_branch)


cdf = sym.Piecewise((0, y <= 0.0), 
                    ((-y - 1)*sym.exp(-y) + 1, y <= 10.0), 
                    (0.999500600772613, True))
eq = sym.Eq(x, cdf)
# Intermediate results are in inline comments
inverse = sym.solve(eq, y, rational=False)  # [Piecewise((-LambertW((x - 1)*exp(-1)) - 1, -LambertW((x - 1)*exp(-1)) - 1 <= 10.0), (nan, True))]
inverse = recreate_piecewise(inverse)  # Piecewise((-LambertW((x - 1)*exp(-1)) - 1, -LambertW((x - 1)*exp(-1)) - 1 <= 10.0))
inverse = replace_lambertw_branch(inverse)  # Piecewise((-LambertW((x - 1)*exp(-1), -1) - 1, -LambertW((x - 1)*exp(-1), -1) - 1 <= 10.0))

if to_be_sorted(inverse):  # -> this throws an error
    inverse = sort_piecewise(inverse)
print(inverse)    

在标记行上它会抛出错误:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-8-4ebf5c3828fb> in <module>()
     28 inverse = replace_lambertw_branch(inverse)  # Piecewise((-LambertW((x - 1)*exp(-1), -1) - 1, -LambertW((x - 1)*exp(-1), -1) - 1 <= 10.0))
     29 
---> 30 if to_be_sorted(inverse):  # -> this throws error
     31     inverse = sort_piecewise(inverse)
     32 print(inverse)

<ipython-input-5-d8df4b6ed407> in to_be_sorted(piecewise)
     22     last_line = piecewise.args[-1]
     23 
---> 24     first_root = piecewise_part_condition_root(first_line)
     25     last_root = piecewise_part_condition_root(last_line)
     26 

<ipython-input-5-d8df4b6ed407> in piecewise_part_condition_root(expression_condition)
     14     condition = expression_condition[1]
     15     equation = replace_inequalities(condition)
---> 16     return sym.solve(equation, x)[0]
     17 
     18 

~/.local/lib/python3.6/site-packages/sympy/solvers/solvers.py in solve(f, *symbols, **flags)
   1063     ###########################################################################
   1064     if bare_f:
-> 1065         solution = _solve(f[0], *symbols, **flags)
   1066     else:
   1067         solution = _solve_system(f, symbols, **flags)

~/.local/lib/python3.6/site-packages/sympy/solvers/solvers.py in _solve(f, *symbols, **flags)
   1632 
   1633     if result is False:
-> 1634         raise NotImplementedError('\n'.join([msg, not_impl_msg % f]))
   1635 
   1636     if flags.get('simplify', True):

NotImplementedError: 
No algorithms are implemented to solve equation -LambertW((x - 1)*exp(-1), -1) - 11

最佳答案

一种方法是以数值方式求解断点 (nsolve),这可能足以满足排序的目的。另一种方法是利用 CDF 是递增函数这一事实,根据 y 的值而不是 x 的值进行排序;也就是说,在中分段得到的不等式的右侧。在你的第二个例子中说明:

cdf = sym.Piecewise((0, y < 4.3), (y - 4.3, y < 12.9), (5*y - 55.9, y <= 13.5), (11.6, True))
inverse = sym.solve(sym.Eq(x, cdf), y, rational=False)
solutions = [piecewise.args[0] for piecewise in inverse]
solutions.sort(key=lambda case: case[1].args[1])
print(sym.Piecewise(*solutions))

打印

Piecewise((x + 4.3, x + 4.3 < 12.9), (x/5 + 11.18, x/5 + 11.18 <= 13.5))

这应该适用于任何递增函数,因为 y 值的递增顺序与 x 值的递增顺序相匹配。

关于python - 如何在通过求解另一个分段函数获得的分段函数中对 ExprCondPair 进行排序?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50428912/

相关文章:

python - 如何使用 Sympy 自动简化二值有限域上的表达式?

optimization - 优化工作调度MiniZinc代码——约束规划

php - PHP中的数独求解/生成算法

algorithm - 在我的递归方法中继续越界以导航并查找 4X4 矩阵中的所有字母组合

python - 计算事件之间的时间差

Pythonic 方式 - 三元 vs 或

python - 如何导出和保存链接的 Jupyter 笔记本?

python - 从 Sympy 中的级数展开中获取 n 阶系数

python matplotlib Canvas 未调整大小

matrix - sympy 复数矩阵的实部/虚部