python - 优化sympy生成的代码

标签 python c optimization sympy

使用 SymPy 求导数(参见这个问题:https://math.stackexchange.com/questions/726104/apply-chain-rule-to-vector-function-with-chained-dot-and-cross-product),我想出了这段代码:

from sympy import *
from sympy.physics.mechanics import *
from sympy.printing import print_ccode
from sympy.utilities.codegen import codegen


x1, x2, x3 = symbols('x1 x2 x3')
y1, y2, y3 = symbols('y1 y2 y3')
z1, z2, z3 = symbols('z1 z2 z3')

u = ReferenceFrame('u')

u1=(u.x*x1 + u.y*y1 + u.z*z1)
u2=(u.x*x2 + u.y*y2 + u.z*z2)
u3=(u.x*x3 + u.y*y3 + u.z*z3)

s1=(u1-u2).normalize()
s2=(u2-u3).normalize()
v=cross(s1, s2)
f=dot(v,v)

df_dy2=diff(f, y2)


print_ccode(df_dy2, assign_to='df_dy2')


[(c_name, c_code), (h_name, c_header)] = codegen( ("df_dy2", df_dy2), "C", "test", header=False, empty=False)

print c_code

这产生了这种美:

#include "test.h"
#include <math.h>
double df_dy2(double x1, double x2, double x3, double y1, double y2, double y3, double z1, double z2, double z3) {
   return ((x1 - x2)*(y2 - y3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - (x2 - x3)*(y1 - y2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))))*(2*(x1 - x2)*(y1 - y2)*(y2 - y3)/(pow(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2), 3.0L/2.0L)*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) + 2*(x1 - x2)*(-y2 + y3)*(y2 - y3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*pow(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2), 3.0L/2.0L)) + 2*(x1 - x2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - 2*(x2 - x3)*pow(y1 - y2, 2)/(pow(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2), 3.0L/2.0L)*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - 2*(x2 - x3)*(y1 - y2)*(-y2 + y3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*pow(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2), 3.0L/2.0L)) + 2*(x2 - x3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2)))) + (-(x1 - x2)*(z2 - z3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) + (x2 - x3)*(z1 - z2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))))*(-2*(x1 - x2)*(y1 - y2)*(z2 - z3)/(pow(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2), 3.0L/2.0L)*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - 2*(x1 - x2)*(-y2 + y3)*(z2 - z3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*pow(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2), 3.0L/2.0L)) + 2*(x2 - x3)*(y1 - y2)*(z1 - z2)/(pow(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2), 3.0L/2.0L)*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) + 2*(x2 - x3)*(-y2 + y3)*(z1 - z2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*pow(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2), 3.0L/2.0L))) + ((y1 - y2)*(z2 - z3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - (y2 - y3)*(z1 - z2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))))*(2*pow(y1 - y2, 2)*(z2 - z3)/(pow(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2), 3.0L/2.0L)*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) + 2*(y1 - y2)*(-y2 + y3)*(z2 - z3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*pow(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2), 3.0L/2.0L)) - 2*(y1 - y2)*(y2 - y3)*(z1 - z2)/(pow(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2), 3.0L/2.0L)*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - 2*(-y2 + y3)*(y2 - y3)*(z1 - z2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*pow(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2), 3.0L/2.0L)) - 2*(z1 - z2)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))) - 2*(z2 - z3)/(sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2) + pow(z1 - z2, 2))*sqrt(pow(x2 - x3, 2) + pow(y2 - y3, 2) + pow(z2 - z3, 2))));
}

相同数字的 sqrts 和 pow 多次出现,可以计算一次以提高可读性和执行时间。但是我不知道怎么...

Q1:您知道让 sympy 自动执行此操作的方法吗?

问题 2:您是否知道使用其他工具对这段代码进行后处理的方法?

Q3:gcc可以在编译时优化吗?为什么?

最佳答案

这是我自己的基于 asmeurers 提示的小脚本:

def sympyToC( symname, symfunc ):
    tmpsyms = numbered_symbols("tmp")
    symbols, simple = cse(symfunc, symbols=tmpsyms)
    symbolslist = map(lambda x:str(x), list(symfunc.atoms(Symbol)) )
    symbolslist.sort()
    varstring=",".join( " double "+x for x in symbolslist )

    c_code = "double "+str(symname)+"("+varstring+" )\n"
    c_code +=  "{\n"
    for s in symbols:
        #print s
        c_code +=  "  double " +ccode(s[0]) + " = " + ccode(s[1]) + ";\n"
    c_code +=  "  r = " + ccode(simple[0])+";\n"
    c_code +=  "  return r;\n"
    c_code += "}\n"
    return c_code

对于python3.5+:

def sympyToC( symname, symfunc ):
    tmpsyms = numbered_symbols("tmp")
    symbols, simple = cse(symfunc, symbols=tmpsyms)
    symbolslist = sorted(map(lambda x:str(x), list(symfunc.atoms(Symbol))))
    varstring=",".join( " double "+x for x in symbolslist )
    c_code = "double "+str(symname)+"("+varstring+" )\n"
    c_code +=  "{\n"
    for s in symbols:
        #print s
        c_code +=  "  double " +ccode(s[0]) + " = " + ccode(s[1]) + ";\n"
    c_code +=  "  r = " + ccode(simple[0])+";\n"
    c_code +=  "  return r;\n"
    c_code += "}\n"
    return c_code

关于python - 优化sympy生成的代码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22665990/

相关文章:

python - 如何将表达式字符串拆分为 Python 中的列表?

python - 由于 Python 不支持 if 条件中的赋值,如何使复杂的 if 条件可维护?

linux - 优化打印计数器的循环

c - 从 char 中提取 char。例如 1 from 123

javascript - 比较对象数组,最佳方式

haskell - Haskell 中的缓存和显式并行性

python - 缝合最终尺寸和偏移

python - 在 numpy 二维数组的每一行中放置一个的快速方法

c - 从集合中查找最接近查询数字的两个不同数字的总和

c++ - 连接两个传感器 dht11 和 dht22 相同的 nodemcu 模块 esp-12e