module RWE2D_source_wemig
  use sep
  use util_
  use fft
  use geometry2D
  use sep_timers_mod

  implicit none

  !! Grid extents (nz,nx,nw are copied from the axes in wemig_init),
  !! number of reference media, smoothing passes and taper width.
  !! nr is declared but not assigned anywhere in this module.
  integer, private :: nz, nx, nw, nr
  integer, private :: nloop, nref, ntap

  !! Samplings and origins; wemig_init multiplies dw and ow by 2*pi.
  real,    private :: dz, dx, dw
  real,    private :: oz, ox, ow
  !! Wavenumber axis: okx = -pi/dx, dkx = 2*pi/(nx*dx).
  real,    private :: okx, dkx
  !! Miscellaneous scalars: propagation flag, stabilisation constant,
  !! pi, and the finite-difference tuning parameters.
  real,    private :: forward, eps, pi
  real,    private :: sixth, c1, c2
  logical, private :: kin, verbose

  !! Per-trace work arrays (all allocated in wemig_init).
  !! Several procedures declare dummy arguments with the same names as
  !! allmask, allfields, allrefs and rdax; inside those procedures the
  !! dummy argument hides the module array.
  integer, allocatable, private :: allmask(:)
  real,    allocatable, private :: msk(:), mtt(:)
  real,    allocatable, private :: tap(:), ikz2d(:)
  real,    allocatable, private :: allrefs(:,:), allfields(:,:)
  complex, allocatable, private :: rwfl(:), rwtt(:), rdax(:)

  !! Finite-difference operator: left/right/centre stencils, the
  !! tridiagonal bands (a,b,c) and the right-hand side / solution u.
  complex, allocatable, private :: lk(:), rk(:), ck(:)
  complex, allocatable, private :: a(:), b(:), c(:), u(:)

  !! Grid scalings and the frequency-(in)dependent coefficient sets.
  real,                 private :: kmu, knu, kro
  real,    allocatable, private :: mu(:), nu(:), ro(:)
  real,    allocatable, private :: m0(:), n0(:), r0(:)

contains

  !----------------------------------------------------------------  
  !! Initialise the module state used by the migration kernels.
  !!
  !! Copies grid sizes and samplings from the axis descriptors, reads the
  !! tuning parameters through from_param, (re)allocates every work array
  !! of the module and builds the lateral taper.
  !!
  !!   ax, az, aw : x, z and frequency axes; their %n, %o and %d are read.
  !!                The frequency sampling and origin are scaled by 2*pi
  !!                (presumably Hz -> rad/s; confirm against the caller).
  !!   ar         : accepted for interface compatibility; not read here.
  !!   forward_in : stored in the module variable forward.
  !!   kin_in     : stored in kin.
  !!   nref_in    : number of reference media (first extent of allrefs).
  !!   verbose_in : stored in verbose.
  !!
  !! The routine may be called more than once: storage left over from an
  !! earlier call is released before the arrays are allocated again.
  subroutine wemig_init(ax,az,aw,ar,forward_in,kin_in,nref_in,verbose_in)
    type(myaxis)     :: ax,az,aw,ar
    integer          :: nref_in
    real             :: forward_in
    logical          :: kin_in,verbose_in

    !! pi is intentionally NOT redeclared here: a local declaration would
    !! hide the module-level pi and leave the latter undefined.
    pi=acos(-1.)

    forward=forward_in
    kin=kin_in
    verbose=verbose_in
    nref=nref_in
    eps=0.0000001

    nz=az%n;    nx=ax%n;    nw=aw%n
    dz=az%d;    dx=ax%d;    dw=aw%d * 2.*pi
    oz=az%o;    ox=ax%o;    ow=aw%o * 2.*pi

    !! Tuning parameters; the third argument is passed as the default
    !! (assumed from the from_param calling convention).
    call from_param("c1",c1,0.5)
    call from_param("c2",c2,0.)
    call from_param("sixth",sixth,0.122996)
    call from_param("ntap",ntap,5)
    call from_param("nloop",nloop,3)

    !! Wavenumber axis centred on zero
    okx=  -pi/    dx
    dkx=2.*pi/(nx*dx)

    !! Grid scalings of the finite-difference coefficients
    kmu = 1./ dz         !! 1/ dz
    knu = 1./(2.*dx**2)  !! 1/( 2 dx^2)
    kro = 1./(dz*dx**2)  !! 1/(dz dx^2)

    !! Release storage from a previous initialisation. Each group below is
    !! allocated together further down, so testing one member is enough.
    if (allocated(lk))      deallocate(lk,rk,ck)
    if (allocated(a))       deallocate(a,b,c,u)
    if (allocated(m0))      deallocate(m0,n0,r0)
    if (allocated(mu))      deallocate(mu,nu,ro)
    if (allocated(rdax))    deallocate(rdax,ikz2d,rwfl,rwtt)
    if (allocated(msk))     deallocate(msk,mtt,tap)
    if (allocated(allmask)) deallocate(allmask,allfields,allrefs)

    allocate(lk(nx),rk(nx),ck(nx))
    allocate(a(nx),b(nx-1),c(nx-1),u(nx))
    allocate(m0(nx),n0(nx),r0(nx))
    allocate(mu(nx),nu(nx),ro(nx))
    allocate(rdax(nx),ikz2d(nx),rwfl(nx),rwtt(nx))
    allocate(msk(nx),mtt(nx),tap(nx))
    allocate(allmask(nx),allfields(nx,4),allrefs(nref,4))

    !! Neutral starting values; rwetaper_init then overwrites the edges
    !! of tap.
    msk=1.;mtt=1.;tap=1.;ikz2d=1.
    call rwetaper_init()

  end subroutine wemig_init

  !----------------------------------------------------------------

  !! WE migration driver.
  !!
  !! For every frequency the wavefield slice rwf(:,iw) is tapered and then
  !! marched from depth level 2 to nz. Each depth step applies
  !!   1. the split-step correction (wemig_ssf),
  !!   2. fth on a copy of the slice, then for every active reference
  !!      medium the phase shift (wemig_phs), fth with the first flag
  !!      set, and the mask-weighted accumulation (wemig_ref),
  !!   3. the lateral taper,
  !! and finally adds the real part of the slice to row iz of Pimg.
  !!
  !!   rwf       : wavefield, second index is frequency; modified in place
  !!   Pimg      : image, first index is depth; accumulated into
  !!   iss       : only echoed in the progress message
  !!   cutfields : coefficient fields, last index is depth
  !!   refs      : reference values, last index is depth; a reference
  !!               whose refs(ir,2,iz) is below 1e-9 is skipped
  !!   cutmask   : reference index per lateral position, last index depth
  subroutine wemig(rwf,Pimg,iss,cutfields,refs,cutmask)
    complex,pointer  :: rwf(:,:)
    real             :: Pimg(:,:),refs(:,:,:),cutfields(:,:,:)
    integer          :: iss,cutmask(:,:)

    complex,pointer  :: wslice(:)
    integer          :: ifreq,idepth,iref
    integer          :: tm_loop,tm_phs,tm_ssf,tm_ffd,tm_ref
    real             :: omega,sgn
    logical          :: ok

    !! Timers: the returned status flags are not inspected.
    ok=init_sep_timers()
    ok=setup_next_timer("WLoop",tm_loop)
    ok=setup_next_timer("PHS",tm_phs)
    ok=setup_next_timer("SSF",tm_ssf)
    ok=setup_next_timer("FFD",tm_ffd)
    ok=setup_next_timer("REF",tm_ref)

    !! . . Forward Scattering Option
    sgn= 1.
    write(0,*) 'STARTING MIGRATION'
    write(0,*) 'USING REFERENCE MAX OF :',nref

    frequencies: do ifreq=1,nw
       call start_timer_num(tm_loop)

       omega=ow+(ifreq-1)*dw
       wslice => rwf(:,ifreq)

       call rwetaper(wslice)

       depths: do idepth=2,nz

          call start_timer_num(tm_ssf)
          call wemig_ssf(wslice,omega,cutmask(:,idepth),cutfields(:,:,idepth),refs(:,:,idepth),sgn)
          call stop_timer_num(tm_ssf)

          !! Keep a transformed copy and rebuild the slice from the
          !! per-reference contributions.
          rwtt=wslice
          wslice=0.

          call fth(.false.,.false.,rwtt)

          references: do iref=1,nref
             !! Written as .not.(x<tol) so that the set of skipped
             !! references is exactly the one selected by "x < tol".
             if (.not. (refs(iref,2,idepth) .lt. 0.000000001)) then
                rwfl=rwtt

                call start_timer_num(tm_phs)
                call wemig_phs(rwfl,omega,idepth,iref,refs(:,:,idepth),sgn)
                call stop_timer_num(tm_phs)

                call fth( .true.,.false.,rwfl)

                !! Finite-difference correction, currently disabled:
                !             call start_timer_num(t4)
                !             call wemig_psc_coef(cutfields(:,2,iz),cutfields(:,3,iz),&
                !                                     refs(ir,2,iz),    refs(ir,3,iz))
                !             call wemig_fds(rwfl,w,iz,rarg)
                !             call stop_timer_num(t4)

                call start_timer_num(tm_ref)
                call wemig_ref(rwfl,idepth,iref,cutmask(:,idepth),wslice)
                call stop_timer_num(tm_ref)
             end if
          end do references

          call rwetaper(wslice)

          Pimg(idepth,:)=Pimg(idepth,:)+real( wslice )

       end do depths

       write(0,*) "SSF",ifreq,nw,iss,maxval(abs(Pimg(:,:)))
       call stop_timer_num(tm_loop)

    end do frequencies
    write(0,*) 'FINISHED OFF SHOT!'

    call print_timers()
  end subroutine wemig

  !----------------------------------------------------------------
  !! Split-step Fourier correction.
  !!
  !! For each lateral position the reference row is selected through
  !! allmask(ix); the phase stored in the module array ikz2d(ix) is built
  !! from the differences between the local fields (columns 2 and 4 of
  !! allfields) and the reference values (columns 2 and 4 of allrefs).
  !! rdax(ix) is then multiplied by cos(phase) + i*rdir*sin(phase).
  !!
  !! The dummy arguments rdax, allmask, allfields and allrefs hide the
  !! module arrays of the same names inside this routine.
  subroutine wemig_ssf(rdax,w,allmask,allfields,allrefs,rdir) 
    complex :: rdax(:)
    real    :: w,rdir,allfields(:,:),allrefs(:,:)
    integer :: allmask(:)

    integer :: ix,jref
    real    :: ref2,ref4,numer,denom

    do ix=1,nx
       jref = allmask(ix)
       ref2 = allrefs(jref,2)
       ref4 = allrefs(jref,4)

       numer = w**2*ref2*(allfields(ix,2)-ref2) - ref4*(allfields(ix,4)-ref4)
       !! eps keeps the square root away from zero
       denom = sqrt(abs(ref2**2*w**2-ref4**2)+eps )

       ikz2d(ix) = -dz*numer/denom
       rdax(ix)  = rdax(ix)*cmplx( cos(ikz2d(ix)) ,rdir*sin(ikz2d(ix)) )
    end do

  end subroutine wemig_ssf
  !----------------------------------------------------------------  

  !! Fourier-domain phase shift for reference medium ir.
  !!
  !! For every wavenumber kx = okx + (ix-1)*dkx the squared vertical
  !! wavenumber is
  !!     (allrefs(ir,2)*w)**2 - allrefs(ir,4)**2 - (allrefs(ir,3)*kx)**2.
  !! Where it is negative the sample is damped by exp(-dz*sqrt(-arg));
  !! otherwise it is rotated by the phase -dz*sqrt(arg).
  !! A second rotation by the phase dz*kx*allrefs(ir,1) is always applied.
  !! Rotations have the form cos(p) + i*rdir*sin(p).
  !!
  !! iz is not used. The dummy arguments rdax and allrefs hide the module
  !! arrays of the same names.
  subroutine wemig_phs(rdax,w,iz,ir,allrefs,rdir)
    complex :: rdax(:)
    real    :: w,rdir,allrefs(:,:)
    integer :: iz,ir

    integer :: ik
    real    :: kx,arg0,arg,phase
    real    :: ref1,ref3

    ref1 = allrefs(ir,1)
    ref3 = allrefs(ir,3)

    !! kx-independent part of the dispersion relation
    arg0 = (allrefs(ir,2)* w)**2 - allrefs(ir,4)**2

    do ik=1,nx
       kx  = okx+(ik-1)*dkx
       arg = arg0 - (ref3*kx)**2

       if ( arg .lt. 0) then
          !! evanescent: real exponential decay
          phase = -dz*abs(sqrt(-arg))
          rdax(ik) = rdax(ik)*exp(phase)
       else
          !! propagating: pure phase rotation
          phase = -dz*( sqrt(arg) )
          rdax(ik) = rdax(ik)*cmplx(cos(phase),rdir*sin(phase))
       end if

       !! term linear in kx
       phase = dz*kx*ref1
       rdax(ik) = rdax(ik)*cmplx(cos(phase),rdir*sin(phase))
    end do

  end subroutine wemig_phs


   !----------------------------------------------------------------

  !! Accumulate the contribution of reference medium ir into rdax.
  !!
  !! A 0/1 indicator of the positions where allmask equals ir is smoothed
  !! nloop times with the two-point average (left+right)/2 and used as the
  !! weight of rwfl:   rdax = rdax + weight*rwfl.
  !!
  !!   rwfl    : wavefield extrapolated with reference ir
  !!   iz      : not used
  !!   ir      : reference index to select
  !!   allmask : reference index per lateral position (hides the module
  !!             array of the same name)
  !!   rdax    : accumulated wavefield, updated in place (hides the module
  !!             array of the same name)
  !!
  !! The module arrays mtt and msk are used as scratch space.
  subroutine wemig_ref(rwfl,iz,ir,allmask,rdax)
    complex :: rdax(:),rwfl(:)
    integer :: iz,ir,ix,ii,allmask(:)

    mtt = merge(1., 0., allmask .eq. ir)

    !! Start from the unsmoothed indicator so that msk is well defined
    !! even when nloop < 1; previously the stale msk of an earlier call
    !! (or its initial all-ones value) was applied in that case.
    msk = mtt

    do ii=1,nloop !! Smoothing Iterations
       msk(1 ) = mtt(1   )
       !! NOTE(review): the right edge copies its inner neighbour whereas
       !! the left edge copies itself - confirm that the asymmetry is
       !! intended. max() only guards the degenerate case nx == 1.
       msk(nx) = mtt(max(nx-1,1))
       do ix=2,nx-1
          msk(ix)=( mtt(ix-1)+mtt(ix+1) )/2.
       end do
       mtt=msk
    end do

    rdax = rdax + msk*rwfl

  end subroutine wemig_ref

  !----------------------------------------------------------------

  !! Build the cosine edge taper in the module array tap.
  !!
  !! When ntap > 1 the first and last ntap samples of tap are set to
  !! cos(pi/2 * j/(ntap-1)) with j = ntap-1 at the outermost sample and
  !! j = 0 at the innermost one. Samples in between keep whatever value
  !! tap held on entry. When ntap <= 1 tap is left untouched.
  !!
  !! If ntap exceeds nx only the nx in-range samples are written; the
  !! taper profile itself is still computed with the requested ntap.
  subroutine rwetaper_init()
    integer :: itap,jtap,nedge
    real :: pi
    pi=acos(-1.)

    if (ntap .gt. 1) then

       !! never index tap outside 1..nx
       nedge = min(ntap,nx)

       !! left edge
       do itap=1,nedge
          jtap = abs(ntap-itap)
          tap(   itap) = cos(pi/2.* jtap/(ntap-1))
       end do

       !! right edge (written second, so it wins where the edges overlap)
       do itap=1,nedge
          jtap = abs(ntap-itap)
          tap(nx-itap+1) = cos(pi/2.* jtap/(ntap-1))
       end do

    end if

  end subroutine rwetaper_init

  !----------------------------------------------------------------

  !! Multiply the wavefield slice dax, sample by sample, with the module
  !! taper tap. dax is expected to have as many samples as tap.
  subroutine rwetaper(dax)
    complex, pointer  :: dax(:)
    integer :: k,first

    !! a pointer dummy keeps the bounds of its target
    first = lbound(dax,1)
    do k=1,size(dax)
       dax(first+k-1) = dax(first+k-1) * tap(k)
    end do
  end subroutine rwetaper

  !----------------------------------------------------------------
!----------------------------------------------------------------
!! . . REINSERTED FD STUFF 
!----------------------------------------------------------------
  !! Pseudo-screen coefficients.
  !!
  !! Fills the module arrays m0, n0 and r0 from the local fields aa, bb
  !! and the reference values a0, b0, already scaled by the grid factors
  !! kmu, knu and kro:
  !!   m0 = kmu
  !!   n0 = knu * ( -c1*(b0/a0)**2*(aa-a0) + (b0/a0)*(bb-b0) )
  !!   r0 = kro * 3*c2*(b0/a0)**2
  subroutine wemig_psc_coef(aa,bb,a0,b0)
    real    :: aa(:),bb(:)
    real    :: a0,  b0
    real    :: ratio

    ratio = b0/a0

    m0 = kmu
    n0 = ( -c1 * ratio**2 *(aa-a0)+ ratio   *(bb-b0) ) * knu
    r0 = ( 3.*c2* ratio**2 ) * kro

  end subroutine wemig_psc_coef
 !----------------------------------------------------------------

  !! Fourier finite-difference coefficients.
  !!
  !! Fills the module arrays m0, n0 and r0 from the local fields aa, bb
  !! and the reference values a0, b0, scaled by kmu, knu and kro:
  !!   m0 = kmu
  !!   n0 = knu * c1 * d1
  !!   r0 = kro * c2
  !! with d1 = a0*((b0/tt)/a0)**2 * ( (bb/b0)**2 * a0/aa - 1 ),
  !! clamped to eps where |d1| < eps.
  !!
  !! The normalisation tt used to be read without ever being assigned
  !! (its assignment was commented out), so d1 was undefined.
  !! It is now fixed to 1: in the complete formulation d1 carries a
  !! factor 1/tt**2 that is undone by "* tt**2" in n0, and with that
  !! factor disabled tt = 1 is the only value that leaves n0 unchanged.
  !! NOTE(review): confirm against the intended formulation; the
  !! alternative is tt = b0 together with re-enabling "* tt**2".
  !!
  !! The second-order term d2 was computed but never used (r0 is the
  !! constant c2), so it is no longer evaluated.
  subroutine wemig_ffd_coef(aa,bb,a0,b0)
    real    :: aa(:),bb(:)
    real    :: a0,b0
    real,    allocatable :: d1(:)
    real    :: tt

    allocate(d1(nx))

    tt = 1.

    d1=a0*((b0/tt)/a0)**2 * ( ( bb/b0)**2* a0/aa    - 1. )

    !! keep d1 away from zero
    where(abs(d1)<eps ) 
       d1=eps
    end where

    m0 =      1    
    n0 = c1 * d1 ! * tt**2
    r0 = c2 !* d2/d1 * tt**2

    m0 = m0 * kmu
    n0 = n0 * knu
    r0 = r0 * kro

    deallocate(d1)
  end subroutine wemig_ffd_coef
  !----------------------------------------------------------------
  !! Print the value range of the coefficient arrays m0, n0 and r0 on
  !! unit 0, one line per array.
  subroutine wemig_coef_report()
    real :: lo,hi

    lo = minval(m0)
    hi = maxval(m0)
    write(0,*) lo,"< mu <",hi

    lo = minval(n0)
    hi = maxval(n0)
    write(0,*) lo,"< nu <",hi

    lo = minval(r0)
    hi = maxval(r0)
    write(0,*) lo,"< ro <",hi

  end subroutine wemig_coef_report
  !----------------------------------------------------------------  
  !! Finite-differences solver.
  !!
  !! Scales the coefficient arrays by frequency (nu = n0/w, ro = r0/w**2),
  !! builds the complex stencils
  !!     lk = (ro + sixth*mu) + i*dir*nu      left-hand side
  !!     rk = (ro + sixth*mu) - i*dir*nu      right-hand side
  !!     ck =  mu
  !! applies the right-hand-side stencil to dax and solves the resulting
  !! tridiagonal system with thr. dax is overwritten with the solution.
  !!
  !!   dax : wavefield slice, updated in place
  !!   w   : frequency; must be non-zero because it is used as a divisor
  !!   iz  : not used
  !!   dir : sign applied to the imaginary part of the stencils
  subroutine wemig_fds(dax,w,iz,dir)
    complex, dimension(:) :: dax
    real                  :: w,dir
    integer               :: iz

    mu = m0
    nu = n0 / w
    ro = r0 / w**2

    lk = cmplx(ro+sixth*mu, dir*nu)
    rk = cmplx(ro+sixth*mu,-dir*nu)
    ck = cmplx(         mu, 0.)

    !! Left-hand side bands
    !!             |                      |
    !!     b       |          a           |     c
    !!  ro + i nu  |  mu - 2(ro + i nu )  |  ro + i nu
    !!             |                      |
    a = ck - 2.*lk
    b = lk(1:nx-1)
    c = lk(2:nx  )

    !! Right-hand side
    !!             |                      |
    !!  ro - i nu  |  mu - 2 (ro - i nu)  | ro - i nu
    !!             |                      |
    !! The two end points have only one neighbour.
    u(1)      = (ck(1)-2.0*rk(1))*dax(1) + rk(1)*dax(2)
    u(2:nx-1) =             rk(2:nx-1) *(dax(1:nx-2)+dax(3:nx)) &
         +    (ck(2:nx-1)-2.0*rk(2:nx-1))* dax(2:nx-1)
    u(nx)     = (ck(nx)-2.0*rk(nx))*dax(nx) + rk(nx)*dax(nx-1)

    call thr(a,b,c,u,nx)
    dax = u
  end subroutine wemig_fds
  !----------------------------------------------------------------
  !! Tridiagonal solver (forward elimination / back substitution).
  !!
  !! Solves the n-by-n system whose row i reads
  !!     c(i-1)*x(i-1) + a(i)*x(i) + b(i)*x(i+1) = v(i)
  !!
  !!   a : main diagonal, n entries; overwritten with the pivots
  !!   b : super-diagonal, n-1 entries; overwritten with b(i)/pivot(i)
  !!   c : sub-diagonal, n-1 entries; c(i-1) couples row i to x(i-1)
  !!   v : right-hand side on entry, solution on return
  !!   n : order of the system; nothing is done for n < 1
  !!
  !! No pivoting is performed, so every pivot must be non-zero.
  subroutine thr(a,b,c,v,n)
    integer, intent(in)    :: n
    complex, intent(inout) :: a(n),b(n-1),v(n)
    complex, intent(in)    :: c(n-1)
    integer :: i

    if (n .lt. 1) return

    v(1)=v(1)/a(1)
    if (n .eq. 1) return      !! b and c are empty for a 1-by-1 system

    b(1)=b(1)/a(1)
    do i=2,n
       a(i)=a(i)-c(i-1)*b(i-1)
       !! b has only n-1 entries: the last row has no super-diagonal
       !! element, so it must not be normalised (it used to be written
       !! one element past the end of b).
       if (i .lt. n) b(i)=b(i)/a(i)
       v(i)=(v(i)-c(i-1)*v(i-1))/a(i)
    end do
    do i=n-1,1,-1
       v(i)=v(i)-b(i)*v(i+1)
    end do
  end subroutine thr
  !----------------------------------------------------------------
  !! Print the range of a complex array on unit 0: first line minimum and
  !! maximum of the real part, second line those of the imaginary part.
  subroutine attribute(x)
    complex :: x(:)
    real :: re_lo,re_hi,im_lo,im_hi

    re_lo = minval( real(x))
    re_hi = maxval( real(x))
    im_lo = minval(aimag(x))
    im_hi = maxval(aimag(x))

    write(0,*) re_lo,re_hi
    write(0,*) im_lo,im_hi
  end subroutine attribute

end module  RWE2D_source_wemig
