python - Python中RK4算法错误

标签 python ode

我正在求解非线性薛定谔 (NLS) 方程:

(1): i*u_t + 0.5*u_xx + abs(u)^2 * u = 0

应用傅立叶变换后,变为:

(2): uhat_t = -0.5*i*k^2 * uhat + i * fft(abs(u)^2 * u)

其中uhatu的傅立叶变换。上面的方程(2)是一个明确的 IVP,可以通过四阶 Runge-Kutta 方法求解。这是我求解方程(2)的代码:

import numpy as np
import math
from matplotlib import pyplot as plt
from matplotlib import animation


#----- Numerical integration of ODE via fixed-step classical Runge-Kutta -----

def RK4(TimeSpan,uhat0,nt):
    h = float(TimeSpan[1]-TimeSpan[0])/nt
    print h 
    t = np.empty(nt+1)
    print np.size(t)        # nt+1 vector
    w = np.empty(t.shape+uhat0.shape,dtype=uhat0.dtype)
    print np.shape(w)       # nt+1 by nx matrix
    t[0]   = TimeSpan[0]    
    w[0,:] = uhat0          # enter initial conditions in w
    for i in range(nt):
        t[i+1]   = t[i]+h   
        w[i+1,:] = RK4Step(t[i], w[i,:],h)
    return w

def RK4Step(t,w,h):
    k1 = h * uhatprime(t,w)
    k2 = h * uhatprime(t+0.5*h, w+0.5*k1*h)
    k3 = h * uhatprime(t+0.5*h, w+0.5*k2*h)
    k4 = h * uhatprime(t+h,     w+k3*h)
    return w + (k1+2*k2+2*k3+k4)/6.

#----- Constructing the grid and kernel functions -----
L   = 40
nx  = 512
x   = np.linspace(-L/2,L/2, nx+1)
x   = x[:nx]  

kx1 = np.linspace(0,nx/2-1,nx/2)
kx2 = np.linspace(1,nx/2,  nx/2)
kx2 = -1*kx2[::-1]
kx  = (2.* np.pi/L)*np.concatenate((kx1,kx2))

#----- Define RHS -----
def uhatprime(t, uhat):
    u = np.fft.ifft(uhat)
    z = -(1j/2.) * (kx**2) * uhat + 1j * np.fft.fft((abs(u)**2) * u)
    return z

#------ Initial Conditions -----
u0      = 1./np.cosh(x)#+1./np.cosh(x-0.4*L)
uhat0   = np.fft.fft(u0)

#------ Solving for ODE -----
TimeSpan = [0,10.]
nt       = 100
uhatsol  = RK4(TimeSpan,uhat0,nt) 
print np.shape(uhatsol)
print uhatsol[:6,:]

我打印了迭代的前6步,错误发生在第6步,我不明白为什么会发生这种情况。 6步的结果是:

nls.py:44: RuntimeWarning: overflow encountered in square
  z = -(1j/2.) * (kx**2) * uhat + 1j * np.fft.fft((abs(u)**2) * u)
(101, 512)
[[  4.02123859e+01 +0.00000000e+00j  -3.90186082e+01 +3.16101312e-14j
    3.57681095e+01 -1.43322854e-14j ...,  -3.12522653e+01 +1.18074871e-13j
    3.57681095e+01 -1.20028987e-13j  -3.90186082e+01 +1.62245217e-13j]
 [  4.02073593e+01 +2.01061092e+00j  -3.90137309e+01 -1.95092228e+00j
    3.57636385e+01 +1.78839803e+00j ...,  -3.12483587e+01 -1.56260675e+00j
    3.57636385e+01 +1.78839803e+00j  -3.90137309e+01 -1.95092228e+00j]
 [  4.01015488e+01 +4.02524105e+00j  -3.89110557e+01 -3.90585271e+00j
    3.56695007e+01 +3.58076808e+00j ...,  -3.11660830e+01 -3.12911766e+00j
    3.56695007e+01 +3.58076808e+00j  -3.89110557e+01 -3.90585271e+00j]
 [  3.98941946e+01 +6.03886019e+00j  -3.87098310e+01 -5.85991079e+00j
    3.54849686e+01 +5.37263725e+00j ...,  -3.10047495e+01 -4.69562640e+00j
    3.54849686e+01 +5.37263725e+00j  -3.87098310e+01 -5.85991079e+00j]
 [  3.95847537e+01 +8.04663227e+00j  -3.84095149e+01 -7.80840256e+00j
    3.52095058e+01 +7.15970026e+00j ...,  -3.07638375e+01 -6.25837011e+00j
    3.52095070e+01 +7.15970040e+00j  -3.84095155e+01 -7.80840264e+00j]
 [  1.47696187e+22 -7.55759947e+22j   1.47709575e+22 -7.55843420e+22j
    1.47749677e+22 -7.56093844e+22j ...,   1.47816312e+22 -7.56511230e+22j
    1.47749559e+22 -7.56093867e+22j   1.47709516e+22 -7.55843432e+22j]]

在第6步,迭代的值是疯狂的。 Aslo,这里发生了溢出错误。

有什么帮助吗?谢谢!!!!

最佳答案

第一次解析时出现了两个明显的错误。

  1. (这不是一个错误,ifft 确实是 fft 的完全逆。在其他库中可能不是这种情况。)

  2. 在 RK4 步骤中,您必须为因子 h 确定一个位置。要么(作为例子,其他类推)

     k2 = f(t+0.5*h, y+0.5*h*k1)
    

     k2 = h*f(t+0.5*h, y+0.5*k1)
    
<小时/>

但是,纠正这些点只会延迟爆炸。存在动态爆炸的可能性并不奇怪,这是从三次项中可以预料到的。一般来说,如果所有项都是线性或次线性的,则只能预期“缓慢”的指数增长。

为了避免“非物理”奇点,必须按与 Lipschitz 常数成反比的方式缩放步长。由于这里的 Lipschitz 常数的大小为 u^2,因此必须动态适应。我发现在区间 [0,1](即 h=0.001)中使用 1000 个步骤不会出现奇点。对于区间 [0,10] 上的 10 000 步,这仍然成立。

<小时/>

更新 原始方程中没有时间导数的部分是自伴的,这意味着函数的范数平方(对绝对值平方的积分)保留在精确的形式中解决方案。因此,总体情况是一个高维“旋转”(参见对已经在 3 维中演化的模式的固体运动学的讨论)。

现在的问题是,函数的某些部分可能会以如此小的半径或如此高的速度“旋转”,以至于时间步长代表旋转的很大一部分甚至多次旋转。这很难用数值方法来捕捉,因此需要减少时间步长。这种现象的总称是“刚性微分方程”:显式龙格-库塔方法不适合刚性问题。

<小时/>

更新2:雇用 methods used before ,可以使用以下方法求解解耦频域中的线性部分(请注意,这些都是逐分量数组运算)

vhat = exp( 0.5j * kx**2 * t) * uhat

它允许具有更大步长的稳定解决方案。与处理 KdV 方程一样,线性部分 i*u_t+0.5*u_xx=0 在 DFT 下解耦为

i*uhat_t-0.5*kx**2*uhat=0 

因此可以很容易地求解成相应的指数函数

exp( -0.5j * kx**2 * t).

然后通过设置使用常量的变化来解决完整的方程

uhat = exp( -0.5j * kx**2 * t)*vhat. 

这减轻了 kx 较大组件的刚度负担,但三次方仍然存在。因此,如果步长变大,数值解就会在很少的步长内爆炸。

下面的工作代码


import numpy as np
import math
from matplotlib import pyplot as plt
from matplotlib import animation


#----- Numerical integration of ODE via fixed-step classical Runge-Kutta -----

def RK4Step(odefunc, t,w,h):
    k1 = odefunc(t,w)
    k2 = odefunc(t+0.5*h, w+0.5*k1*h)
    k3 = odefunc(t+0.5*h, w+0.5*k2*h)
    k4 = odefunc(t+h,     w+k3*h)
    return w + (k1+2*k2+2*k3+k4)*(h/6.)

def RK4Stream(odefunc,TimeSpan,uhat0,nt):
    h = float(TimeSpan[1]-TimeSpan[0])/nt
    print(f"step size {h}") 
    w = uhat0
    t = TimeSpan[0]
    while True:
        w = RK4Step(odefunc, t, w, h)
        t = t+h
        yield t,w

#----- Constructing the grid and kernel functions -----
L   = 40
nx  = 512
x, dx = np.linspace(-L/2,L/2, nx+1, retstep=True)
x   = x[:-1]  # periodic boundary, last same as first

kx  = 2*np.pi*np.fft.fftfreq(nx, dx) # angular frequencies for the fft bins

def uhat2vhat(t,uhat):
    return np.exp( 0.5j * (kx**2) *t) * uhat

def vhat2uhat(t,vhat):
    return np.exp(- 0.5j * (kx**2) *t) * vhat

#----- Define RHS -----
def uhatprime(t, uhat):
    u = np.fft.ifft(uhat)
    return - 0.5j * (kx**2) * uhat + 1j * np.fft.fft((abs(u)**2) * u)

def vhatprime(t, vhat):
    u = np.fft.ifft(vhat2uhat(t,vhat))
    return  1j * uhat2vhat(t, np.fft.fft((abs(u)**2) * u) )

#------ Initial Conditions -----
u0      = 1./np.cosh(x) #+ 1./np.cosh(x+0.4*L)+1./np.cosh(x-0.4*L) #symmetric or remove jump at wrap-around
uhat0   = np.fft.fft(u0)

#------ Solving for ODE -----
t0 = 0; tf = 10.0;
TimeSpan = [t0, tf]
# nt       = 500 # limit case, barely stable, visible spurious bumps in phase
nt       = 1000 # boring  but stable. smaller step sizes give same picture
vhat0 = uhat2vhat(t0,uhat0)

fig = plt.figure()
fig = plt.figure()
gs = fig.add_gridspec(3, 2)
ax1 = fig.add_subplot(gs[0, :]) 
ax2 = fig.add_subplot(gs[1:, :])
ax1.set_ylim(-0.2,2.5); ax1.set_ylabel("$u$ amplitude")
ax2.set_ylim(-6.4,6.4); ax2.set_ylabel("$u$ angle"); ax2.set_xlabel("$x$")

line1, = ax1.plot(x,u0)
line2, = ax2.plot(x,u0*0)

vhatstream = RK4Stream(vhatprime,[t0,tf],vhat0,nt)

def animate(i):
    t,vhat = vhatstream.next()
    print(f"time {t}")
    u = np.fft.ifft(vhat2uhat(t,vhat))
    line1.set_ydata(np.real(np.abs(u)))
    angles = np.real(np.angle(u))
    # connect the angles over multiple periods
    offset = 0;
    tau = 2*np.pi
    if angles[0] > 1.5: offset = -tau
    if angles[0] < -1.5: offset = tau
    for i,a in enumerate(angles[:-1]):
        diff_a = a-angles[i+1]
        angles[i] += offset
        if diff_a > 2 : 
            offset += tau
            if offset > 9: offset = tau-offset
        if diff_a < -2 : 
            offset -= tau
            if offset < -9: offset = -tau-offset
    angles[-1] += offset
    line2.set_ydata(angles)
    return line1,line2

anim = animation.FuncAnimation(fig, animate, interval=15000/nt+10, blit=False)

plt.show()

可以通过每帧计算多个 RK4 步骤来加快动画速度,从而增加可见的步长。

如果想要使用 scipy.integrate 中的 ODE 求解器,则必须实现一些包装器,因为这些包装器并未针对复杂值数据的使用进行强化。

# the stepper functions can not handle complex valued data
def RK45Stream(odefunc,TimeSpan,uhat0,nt):
    def odefuncreal(t,ureal):
        u = ureal.reshape([2,-1])
        deriv = odefunc(t,u[0]+1j*u[1])
        return  np.concatenate([deriv.real, deriv.imag])
    t0,tf = TimeSpan
    h = float(tf-t0)/nt
    print("step size ", h) 
    w = np.concatenate([uhat0.real, uhat0.imag])
    t = t0
    stepper = RK45(odefuncreal, t0, w, tf, atol=1e-9, rtol=1e-12)
    out_t = t0
    while True:
        t = t+h
        while t > stepper.t: stepper.step()
        if t>out_t: out_t, sol = stepper.t, stepper.dense_output()
        w = sol(t); w=w.reshape([2,-1])
        yield t,w[0]+1j*w[1]

关于python - Python中RK4算法错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29803342/

相关文章:

python - 在数据分析工具 QLikView 中运行 Python 脚本

python - 弹性摆系统的龙格库塔

python - 在找到局部最大值之前,我可以与 scipy 的 odeint 集成吗?

javascript - 将查询集列表传递给带有日期时间字段的 JavaScript 函数

python - 使用 Python PIL 调整大小的图像较暗

java - Java 的 Python 计数器替代品

python - 奇数和代码问题

python - Runge-Kutta 代码不与内置方法收敛

r - 使用RStudio中deSolve包中的dede来求解时滞ODE

matlab - 求解和绘制分段 ODE