[SOLVED] Pseudo-spectral with RK4 integration

dwsmith

Feb 1, 2012

I am using the pseudo-spectral method with RK4 time integration to solve the nonlinear KdV equation.

My code works for both the linear KdV equation and the NLS equation; however, it is not working for the nonlinear KdV equation \(u_t + u_{xxx} + 6uu_x = u_t + u_{xxx} + 3(u^2)_x = 0\). I want to view the solution on \(0\leq t\leq 10\) and \(-40\leq x\leq 40\), so \(L = 80\), with an accuracy of \(10^{-5}\) or better. RK4 is fourth-order accurate in time (global error \(O((\Delta t)^4)\)), and the pseudo-spectral method is spectrally accurate in space: for smooth solutions the error decays like \(e^{-c/\Delta x}\) for some constant \(c > 0\).
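The spatial accuracy can be checked on its own, before any time stepping. A minimal sketch, assuming the same domain (\(L = 80\)) and the initial profile \(u = 2\,\mathrm{sech}^2(x+20)\) from the Matlab script: differentiate it with the FFT for several \(N\) and compare with the exact derivative \(-4\,\mathrm{sech}^2(x+20)\tanh(x+20)\). The function name `spectral_derivative_error` is hypothetical.

```python
import numpy as np

def spectral_derivative_error(L, N):
    """Max error of the FFT derivative of u = 2 sech^2(x + 20) on [-L/2, L/2)."""
    dx = L / N
    x = -L / 2.0 + dx * np.arange(N)            # N equispaced points, periodic grid
    k = 2.0 * np.pi * np.fft.fftfreq(N, d=dx)    # wavenumbers in FFT order
    u = 2.0 / np.cosh(x + 20.0) ** 2
    ux_exact = -4.0 * np.tanh(x + 20.0) / np.cosh(x + 20.0) ** 2
    ux_fft = np.real(np.fft.ifft(1j * k * np.fft.fft(u)))
    return np.max(np.abs(ux_fft - ux_exact))

errors = {N: spectral_derivative_error(80.0, N) for N in (64, 128, 256, 512)}
for N, err in errors.items():
    print(f"N = {N:4d}  dx = {80.0 / N:.4f}  max error = {err:.2e}")
```

The error should fall off faster than any power of \(N\) until round-off takes over, which indicates how large \(N\) has to be on this domain before the spatial error alone is below the \(10^{-5}\) target.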

I used the relation \(\mathcal{F}(u^2) = \mathcal{F}^{-1}(u)\,\mathcal{F}^{-1}(u)\) for the nonlinear term, so the problem must be related to how I set up the inverse transforms for \((u^2)_x\).
Below are the codes that work, followed by the codes that do not work, with their plots.
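For comparison, here is a sketch of the standard pseudo-spectral evaluation of the nonlinear term: square \(u\) in physical space, transform, multiply by \(ik\), and transform back, i.e. \((u^2)_x = \mathcal{F}^{-1}\big(ik\,\mathcal{F}(u^2)\big)\). Note that \(\mathcal{F}(u^2)\) is (up to normalization) the convolution of \(\mathcal{F}(u)\) with itself, not a product of inverse transforms. The snippet compares this form and the form used in the posted code against the exact \((u^2)_x = 2uu_x\) for \(u = 2\,\mathrm{sech}^2(x+20)\); building the grid with `np.fft.fftfreq` is my own choice, not taken from the post.

```python
import numpy as np

L, N = 80.0, 512
dx = L / N
x = -L / 2.0 + dx * np.arange(N)
k = 2.0 * np.pi * np.fft.fftfreq(N, d=dx)
k1 = 1j * k
u = 2.0 / np.cosh(x + 20.0) ** 2

# Exact (u^2)_x = 2 u u_x, with u_x = -4 sech^2(x + 20) tanh(x + 20)
u_x = -4.0 * np.tanh(x + 20.0) / np.cosh(x + 20.0) ** 2
exact = 2.0 * u * u_x

# Pseudo-spectral: square in physical space, transform, multiply by ik, invert
correct = np.real(np.fft.ifft(k1 * np.fft.fft(u * u)))

# The form in the posted code: ifft(k1 * ifft(u) * ifft(u))
posted = np.real(np.fft.ifft(k1 * np.fft.ifft(u) * np.fft.ifft(u)))

err_correct = np.max(np.abs(correct - exact))
err_posted = np.max(np.abs(posted - exact))
print(f"square-then-transform error: {err_correct:.2e}")
print(f"posted-form error:           {err_posted:.2e}")
```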

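Numerical instability occurs when \(\Delta t > \frac{2\sqrt{2}(\Delta x)^3}{\pi^3}\) (the \(u_{xxx}\) term gives purely imaginary eigenvalues \(ik^3\) with \(|k| \le \pi/\Delta x\), and RK4 is stable on the imaginary axis only up to \(2\sqrt{2}\)), but if I choose a stable value for \(\Delta t\), the plots do not even come close to the correct solution.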
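The RK4 step limit depends on the order of the highest derivative: for \(u_{xxx}\) it is \(\Delta t \le 2\sqrt{2}(\Delta x)^3/\pi^3\), while \(2\sqrt{2}(\Delta x)^2/\pi^2\) is the corresponding bound for the \(u_{xx}\) term of the NLS equation. A sketch that evaluates the bound for the parameters used in the scripts below, assuming only the linear dispersive term matters (the nonlinear term lowers the limit somewhat). The helper names `rk4_dt_limit` and `rk4_gain` are hypothetical.

```python
import numpy as np

def rk4_dt_limit(L, N, order=3):
    """Largest stable RK4 step for a dispersive term with Fourier symbol of size |k|**order.

    order=3 corresponds to u_xxx (KdV), order=2 to i*u_xx (NLS). The eigenvalues are
    purely imaginary with |k| <= pi/dx, and RK4 is stable on the imaginary axis up to 2*sqrt(2).
    """
    dx = L / N
    kmax = np.pi / dx
    return 2.0 * np.sqrt(2.0) / kmax ** order

def rk4_gain(y):
    """|R(iy)| for the RK4 stability polynomial R(z) = 1 + z + z^2/2 + z^3/6 + z^4/24."""
    z = 1j * y
    return abs(1 + z + z ** 2 / 2 + z ** 3 / 6 + z ** 4 / 24)

runs = [(200.0, 512, 0.005), (80.0, 200, 0.02), (80.0, 100, 0.02), (80.0, 100, 0.05)]
for L, N, dt in runs:
    limit = rk4_dt_limit(L, N)
    status = "stable" if dt <= limit else "unstable"
    print(f"L = {L:g}, N = {N}: dt limit = {limit:.4f}, dt = {dt} is {status}")
```

For the posted parameters this places the linear run (\(L = 200\), \(N = 512\), \(\Delta t = 0.005\)) just inside the limit, the Python nonlinear run (\(N = 200\)) outside it at \(\Delta t = 0.02\), and the Matlab nonlinear run (\(N = 100\)) inside it at \(\Delta t = 0.02\) but outside it at \(\Delta t = 0.05\).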
Code in Python (linear KdV):
Code:
#!/usr/bin/env ipython 
#  The pseudo-spectral method for solving the Linear KdV equation 
#  u_t + u_{xxx} = 0 
 
import matplotlib.pyplot as plt 
import numpy as np 
 
L = 200.0 
N = 512.0 
dt = 0.005 
tmax = 5 
nmax = int(np.floor(tmax / dt))  #  also try ceiling 
dx = L / N 
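x = np.arange(-L / 2.0, L / 2.0 - dx/2., dx)   # N grid points on [-L/2, L/2)
# Wavenumbers in FFT order: 0, 1, ..., N/2-1, -N/2, ..., -1, scaled by 2*pi/L
k = np.hstack((np.arange(0,N / 2.0 - .1),np.arange(-N/2., 0))).T * 2.0 * np.pi / L
k3 = k ** 3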
FWHM = 0.3 * np.pi 
alpha = np.sqrt(0.5) 
u = alpha * np.exp(-x ** 2 / (2 * FWHM ** 2)) 
udata = u 
tdata = 0 
 
for nn in range(1, nmax+1):
    # Classical RK4. Each stage evaluates u_t = -u_xxx spectrally:
    # F(-u_xxx) = -(ik)^3 F(u) = i k^3 F(u), hence 1j * ifft(k3 * fft(u)).
    du1 = 1j * np.fft.ifft(k3 * np.fft.fft(u))
    v = u + 0.5 * du1 * dt
    du2 = 1j * np.fft.ifft(k3 * np.fft.fft(v))
    v = u + 0.5 * du2 * dt
    du3 = 1j * np.fft.ifft(k3 * np.fft.fft(v))
    v = u + du3 * dt
    du4 = 1j * np.fft.ifft(k3 * np.fft.fft(v))
    u = u + (du1 + 2 * du2 + 2 * du3 + du4) * dt / 6.0
    # Store roughly 100 snapshots for plotting
    if np.mod(nn, np.floor(nmax / 100.0)) == 0:
        udata = np.vstack([udata, u])
        tdata = np.vstack([tdata, nn * dt])
 
plt.pcolor(x, tdata.ravel(), np.real(udata)) 
plt.xlabel('$x$') 
plt.ylabel('Time') 
plt.title('Linear Dispersion in the $U_t+U_{xxx}=0$ Equation',fontsize=13) 
plt.show()
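Because the linear equation has the exact Fourier-space solution \(\hat u(k,t) = \hat u(k,0)\,e^{ik^3 t}\), the RK4 time-stepping error can be measured directly. A sketch that reuses the parameters of the linear script (with the wavenumbers built via `np.fft.fftfreq`, which gives the same ordering) and compares the RK4 result with the exact propagator; `rhs` is a hypothetical helper name.

```python
import numpy as np

L, N, dt, tmax = 200.0, 512, 0.005, 5.0   # same parameters as the linear script
dx = L / N
x = -L / 2.0 + dx * np.arange(N)
k = 2.0 * np.pi * np.fft.fftfreq(N, d=dx)
u0 = np.sqrt(0.5) * np.exp(-x ** 2 / (2 * (0.3 * np.pi) ** 2))

def rhs(u):
    """u_t = -u_xxx evaluated spectrally: ifft(i k^3 fft(u))."""
    return 1j * np.fft.ifft(k ** 3 * np.fft.fft(u))

u = u0.astype(complex)
nmax = int(round(tmax / dt))
for _ in range(nmax):
    d1 = rhs(u)
    d2 = rhs(u + 0.5 * dt * d1)
    d3 = rhs(u + 0.5 * dt * d2)
    d4 = rhs(u + dt * d3)
    u = u + dt * (d1 + 2 * d2 + 2 * d3 + d4) / 6.0

# Exact solution: each Fourier mode is multiplied by exp(i k^3 t)
u_exact = np.fft.ifft(np.fft.fft(u0) * np.exp(1j * k ** 3 * nmax * dt))
err = np.max(np.abs(u - u_exact))
print(f"t = {nmax * dt:g}, max |u_RK4 - u_exact| = {err:.2e}")
```

Repeating this with a halved \(\Delta t\) should reduce the error by roughly a factor of 16 while the time error dominates, which is a quick way to confirm fourth-order behaviour.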

Code in Matlab (NLS):
Code:
% p2.m: the pseudo-spectral method for solving the NLS equation
% iu_t+u_{xx}+2|u|^2u=0.

  L = 80; 
  N = 256; 
  dt = 0.02;  
  tmax = 20; 
  nmax = round(tmax/dt);
  dx = L/N; 
  x = (-L/2:dx:L/2-dx)'; 
  k = [0:N/2-1 -N/2:-1]'*2*pi/L; 
  k2 = k.^2;
  u = 1.2*sech(1.2*(x + 20)).*exp(1i*x) + 0.8*sech(0.8*x);
  udata = u; 
  tdata = 0;
  
  for nn = 1:nmax                               % integration begins
    du1 = 1i*(ifft(-k2.*fft(u)) + 2*u.*u.*conj(u));  
    v = u + 0.5*du1*dt;
    du2 = 1i*(ifft(-k2.*fft(v)) + 2*v.*v.*conj(v));  
    v=u+0.5*du2*dt;
    du3 = 1i*(ifft(-k2.*fft(v)) + 2*v.*v.*conj(v));  
    v = u + du3*dt;
    du4 = 1i*(ifft(-k2.*fft(v)) + 2*v.*v.*conj(v));
    u = u + (du1 + 2*du2 + 2*du3 + du4)*dt/6;
    if mod(nn, round(nmax/25)) == 0
       udata = [udata u]; 
       tdata = [tdata nn*dt];
    end
  end
  % integration ends
  
  waterfall(x, tdata, abs(udata'));           % solution plotting
  colormap(jet(128)); view(10, 60)
  text(-2,  -6, 'x', 'fontsize', 15)
  text(50, 5, 't', 'fontsize', 15)
  zlabel('|u|', 'fontsize', 15)
  axis([-L/2 L/2 0 tmax 0 2]); grid off
  set(gca, 'xtick', [-40 -20 0 20 40])
  set(gca, 'ytick', [0 10 20])
  set(gca, 'ztick', [0 1 2])

Code in Python (nonlinear KdV):
Code:
#!/usr/bin/env ipython
#  The pseudo-spectral method for solving the nonlinear KdV equation
#  u_t + u_{xxx} + 6uu_x = 0

import numpy as np
import pylab

L = 80.0
N = 200.0
dt = 0.02   # and 0.05
tmax = 10
nmax = int(np.floor(tmax / dt))  #  also try ceil/floor
dx = L / N
x = np.arange(-L / 2.0, L / 2.0 - dx / 2.0, dx)   # N grid points on [-L/2, L/2)
k = np.hstack((np.arange(0, N / 2.0),
               np.arange(-N / 2.0, 0))).T * 2.0 * np.pi / L   # N wavenumbers, FFT order
k1 = 1j * k
k3 = (1j * k) ** 3
u = 2.0 / np.cosh(x + 20.0) ** 2   # 2 sech^2(x + 20), as in the Matlab version
udata = u
tdata = 0.0

for nn in range(1, nmax + 1):
    du1 = (-np.fft.ifft(k3 * np.fft.fft(u)) -
           3 * np.fft.ifft(k1 * np.fft.ifft(u) * np.fft.ifft(u)))
    v = u + 0.5 * du1 * dt
    du2 = (-np.fft.ifft(k3 * np.fft.fft(v)) -
           3 * np.fft.ifft(k1 * np.fft.ifft(v) * np.fft.ifft(v)))
    v = u + 0.5 * du2 * dt
    du3 = (-np.fft.ifft(k3 * np.fft.fft(v)) -
           3 * np.fft.ifft(k1 * np.fft.ifft(v) * np.fft.ifft(v)))
    v = u + du3 * dt
    du4 = (-np.fft.ifft(k3 * np.fft.fft(v)) -
           3 * np.fft.ifft(k1 * np.fft.ifft(v) * np.fft.ifft(v)))
    u = u + (du1 + 2.0 * du2 + 2.0 * du3 + du4) * dt / 6.0
    if np.mod(nn, np.ceil(nmax / 100.0)) == 0:
        udata = np.vstack((udata, u))
        tdata = np.vstack((tdata, nn * dt))


fig = pylab.figure()
ax = fig.add_subplot(111)
ax.pcolor(x, tdata.ravel(), np.real(udata))
pylab.xlim((-40, 40))
pylab.ylim((0, 2))
pylab.show()

Plots (images not available): numerical instability plot; stable plot.

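A sketch of a corrected nonlinear solver in Python, under these assumptions: the nonlinear term is evaluated as \((u^2)_x = \mathcal{F}^{-1}\big(ik\,\mathcal{F}(u^2)\big)\); the grid has exactly \(N\) points on \([-L/2, L/2)\); the initial condition is the Matlab profile \(2\,\mathrm{sech}^2(x+20)\), which is the exact one-soliton solution \(u = 2\,\mathrm{sech}^2(x - 4t + 20)\) of \(u_t + 6uu_x + u_{xxx} = 0\); the real part is taken after each derivative evaluation; and \(N = 256\), \(\Delta t = 0.001\) are chosen to stay well inside the RK4 limit. `tmax` is set to 1 only to keep the check short. The helper names `kdv_rhs` and `solve_kdv` are hypothetical. In the Matlab script the corresponding change to each stage would be replacing `ifft(k1.*ifft(u).*ifft(u))` with `ifft(k1.*fft(u.^2))`.

```python
import numpy as np

def kdv_rhs(u, k):
    """u_t = -u_xxx - 3 (u^2)_x, both derivatives taken in Fourier space."""
    u_xxx = np.real(np.fft.ifft((1j * k) ** 3 * np.fft.fft(u)))
    u2_x = np.real(np.fft.ifft(1j * k * np.fft.fft(u * u)))  # square first, then transform
    return -u_xxx - 3.0 * u2_x

def solve_kdv(L=80.0, N=256, dt=0.001, tmax=1.0):
    """Classical RK4 in physical space with pseudo-spectral derivatives."""
    dx = L / N
    x = -L / 2.0 + dx * np.arange(N)
    k = 2.0 * np.pi * np.fft.fftfreq(N, d=dx)
    u = 2.0 / np.cosh(x + 20.0) ** 2          # one-soliton initial condition
    nmax = int(round(tmax / dt))
    for _ in range(nmax):
        d1 = kdv_rhs(u, k)
        d2 = kdv_rhs(u + 0.5 * dt * d1, k)
        d3 = kdv_rhs(u + 0.5 * dt * d2, k)
        d4 = kdv_rhs(u + dt * d3, k)
        u = u + dt * (d1 + 2.0 * d2 + 2.0 * d3 + d4) / 6.0
    return x, u, nmax * dt

x, u, t = solve_kdv()
# 2 sech^2(x - 4t + 20) is an exact soliton of u_t + 6 u u_x + u_xxx = 0 (speed 4)
u_exact = 2.0 / np.cosh(x - 4.0 * t + 20.0) ** 2
err = np.max(np.abs(u - u_exact))
print(f"t = {t:.2f}, max |u - u_exact| = {err:.2e}")
```

Because the exact soliton is known, the maximum error at the final time can be compared directly against the \(10^{-5}\) target while varying \(N\) and \(\Delta t\).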

Code in Matlab (nonlinear KdV):
Code:
% p2.m: the pseudo-spectral method for solving the KdV equation
% u_t+u_{xxx}+3(u^2)_x=0.

  L = 80; 
  N = 100; 
  dt = 0.02;  % and 0.05
  tmax = 10; 
  nmax = round(tmax/dt);
  dx = L/N; 
  x = (-L/2:dx:L/2-dx)'; 
  k = [0:N/2-1 -N/2:-1]'*2*pi/L; 
  k1 = 1i*k;
  k2 = (1i*k).^2;
  k3 = (1i*k).^3;
  u = 2*sech(x + 20).^2;
  udata = u;
  tdata = 0;
  
  for nn = 1:nmax                               % integration begins
    du1 = -ifft(k3.*fft(u)) - 3*ifft(k1.*ifft(u).*ifft(u));  
    v = u + 0.5*du1*dt;
    du2 = -ifft(k3.*fft(v)) - 3*ifft(k1.*ifft(v).*ifft(v));  
    v = u + 0.5*du2*dt;
    du3 = -ifft(k3.*fft(v)) - 3*ifft(k1.*ifft(v).*ifft(v));  
    v = u + du3*dt;
    du4 = -ifft(k3.*fft(v)) - 3*ifft(k1.*ifft(v).*ifft(v));
    u = u + (du1 + 2*du2 + 2*du3 + du4)*dt/6;
    if mod(nn, round(nmax/15)) == 0
       udata = [udata u]; 
       tdata = [tdata nn*dt];
    end
  end
  % integration ends
  
  waterfall(x, tdata, abs(udata'));           % solution plotting
  colormap(jet(128)); 
  view(10, 60)
  %text(-2,  -6, 'x', 'fontsize', 15)
  axis([-40 40 0 tmax 0 2]); 
  grid off

Plots (images not available): numerically unstable plot; stable plot.