module OW_pspi_dd
!
! PSPI-style (phase shift plus interpolation) depth extrapolation of a
! source wavefield, frequency by frequency, with a subsurface-offset
! imaging condition accumulated into an image volume.
!
! Module state (isp, ni, inode, node, ixs) is set through owpspi_init and
! read by owpspi.
!

use OWvel
use OW_type
use OW_image
use OW_fftw
use OW_grabvel
use OW_parms
use, intrinsic :: iso_fortran_env, only: error_unit

implicit none

public                                  :: owpspi_init, owpspi
private                                 :: isp, ni, inode, node, ixs

! isp   : lateral index at which the source sample is injected into ws
! ni    : first lateral index of the window of image that owpspi updates
! inode : GG slot this caller accumulates into
! node  : number of GG slots allocated and summed into image
! ixs   : shot identifier, only printed in progress messages
integer                                 :: isp, ni, inode, node, ixs

contains

!---------------------------------------------------------
! Store the configuration used by subsequent owpspi calls.
!
!   isp_in   : lateral index at which the source is injected
!   ni_in    : first lateral index of the image window written by owpspi
!   inode_in : GG slot accumulated by this caller
!   node_in  : number of GG slots allocated and summed
!   ixs_in   : shot identifier, printed in progress messages
!---------------------------------------------------------
subroutine owpspi_init(isp_in, ni_in, inode_in, node_in, ixs_in)

integer, intent(in)                     :: isp_in, ni_in, inode_in, node_in, ixs_in

isp   = isp_in
ni    = ni_in
inode = inode_in
node  = node_in
ixs   = ixs_in

end subroutine owpspi_init

!---------------------------------------------------------
! Extrapolate the source wavefield in depth for every frequency and
! accumulate a subsurface-offset image.
!
!   source : source(1,iw,1) is injected at lateral index isp for frequency
!            iw (other entries are not used yet)
!   vel    : velocity, indexed vel(iz, ix)
!   image  : image(iz, ish, ix); the window image(:,:,ni:ni+kxm%n-1) is
!            incremented by the sum of all GG slots
!
! Per depth step:
!   1. forward FFT in x;
!   2. phase shift with the reference velocities returned by owvref
!      (a single shift when the layer velocity is constant);
!   3. inverse FFT and a split-step correction;
!   4. linear interpolation between the reference wavefields;
!   5. the imaging condition.
!
! NOTE(review): assumes kxm%n == rec%h%n (tfld = ws requires conforming
! shapes) and that owvref returns vref in ascending order -- TODO confirm.
! NOTE(review): the imaging condition correlates ws with itself; an older
! signature owpspi(recv,source,vel,GG) and a commented-out
! owimage(ws,wr,...) call suggest a receiver wavefield was intended -- confirm.
!---------------------------------------------------------
subroutine owpspi(source, vel, image)

complex, intent(in)                     :: source(:,:,:)
real,    intent(in)                     :: vel(:,:)
complex, intent(inout)                  :: image(:,:,:)

! Sign of the phase-shift exponent (formerly the local "sign", renamed so it
! no longer shadows the SIGN intrinsic).
real, parameter                         :: phase_sign = -1.0

integer                                 :: iw, iz, iref, ixw, ish, ino, ihso
real                                    :: minv, maxv, sqrt_n
real, allocatable                       :: vref(:)
complex, allocatable                    :: ws(:), gs(:,:)
complex, allocatable                    :: GG(:,:,:,:)
! planF / planI are executed in place on tfld (input copied in, result read
! back out); presumably they are built over tfld by owfftw.
complex, pointer                        :: tfld(:)

sqrt_n = sqrt(real(kxm%n))

! Per-node image slots.  Zeroed because only GG(2:,:,:,inode) is written
! below, yet every depth and every slot is summed into image at the end.
allocate(GG(img%z%n, img%sh%n, rec%h%n, node))
GG = (0.0, 0.0)

! Work arrays keep the same shape for every frequency: allocate once.
allocate(ws(rec%h%n), gs(rec%h%n, nref), vref(nref))

!--------------------------------------------------
! Setting FFTW plans
!
allocate(tfld(kxm%n))
call owfftw(tfld)

omega: do iw = 1, rec%w%n

    ! Fresh wavefield for every frequency: a single source sample at isp.
    ws   = (0.0, 0.0)
    gs   = (0.0, 0.0)
    vref = 0.0
    ws(isp) = source(1,iw,1) ! to be modified to handle generalized sources

    depth: do iz = 2, img%z%n

        ! Reference velocities for this depth row.
        call owvref(nref, vel(iz,:), vref, minv, maxv)

        ! x -> kx
        tfld = ws
        call sfftw_execute(planF)
        ws = tfld / sqrt_n

        ! Exact comparison is intended: minv/maxv come from the same row,
        ! so equality means a laterally constant velocity.
        vxz: if (minv == maxv) then
            call phase_shift(ws, minv)
            tfld = ws
            call sfftw_execute(planI)
            ws = tfld / sqrt_n
        else
            velref: do iref = 1, nref
                gs(:,iref) = ws
                call phase_shift(gs(:,iref), vref(iref))

                ! kx -> x
                tfld = gs(:,iref)
                call sfftw_execute(planI)
                gs(:,iref) = tfld / sqrt_n

                ! Split step: correct for vel(iz,:) differing from vref(iref).
                gs(:,iref) = gs(:,iref) * exp(cmplx(0, -phase_sign*v%z%d*rec%w%all(iw)) &
                                              * (1./vel(iz,:) - 1./vref(iref)))
            end do velref

            call interpolate_reference_fields(vel(iz,:))
        end if vxz

        !
        ! Imaging Condition
        ! ihso is the subsurface half-offset in x samples; only points where
        ! both ixw-ihso and ixw+ihso lie in 1..kxm%n contribute.
        !
        do ish = 1, img%sh%n
            ihso = floor(((ish-1)*img%sh%d + img%sh%o) / (img%xm%d))
            do ixw = 1 + abs(ihso), kxm%n - abs(ihso)
                GG(iz,ish,ixw,inode) = GG(iz,ish,ixw,inode) + conjg(ws(ixw-ihso)) * ws(ixw+ihso)
            end do
        end do

    end do depth

    if (mod(iw, 10) == 1) write(error_unit, '(a,i5,3(a,i5))') "DONE FREQUENCY # ", iw, &
        " node = ", inode, " OUT OF ", sou%w%n, " / SHOT # ", ixs

end do omega

!
! Summing images
!
do ino = 1, node
    image(:,:,ni:ni+kxm%n-1) = image(:,:,ni:ni+kxm%n-1) + GG(:,:,:,ino)
end do

! NOTE(review): tfld is left allocated because the FFTW plans may still
! reference its storage; this leaks kxm%n complex values per call -- confirm
! plan ownership in OW_fftw before deallocating here.

contains

    !------------------------------------------------------
    ! Multiply the kx-domain field fld by the one-step phase-shift operator
    ! for reference velocity vr at frequency rec%w%all(iw), using the step
    ! v%z%d.  Components with positive sqa get a pure phase factor; the
    ! others are exponentially damped.
    !------------------------------------------------------
    subroutine phase_shift(fld, vr)
        complex, intent(inout)          :: fld(:)
        real,    intent(in)             :: vr
        integer                         :: ikx
        real                            :: sqa

        do ikx = 1, kxm%n
            sqa = rec%w%all(iw)**2 / vr**2 - kxm%all(ikx)**2
            if (sqa > 0) then
                fld(ikx) = fld(ikx) * exp(cmplx(0, -phase_sign*v%z%d*sqrt(sqa)))
            else
                fld(ikx) = fld(ikx) * exp(-sqrt(-sqa)*v%z%d)
            end if
        end do
    end subroutine phase_shift

    !------------------------------------------------------
    ! Build ws(x) from the reference wavefields gs(:,1:nref).
    ! vrow(x) is linearly interpolated between the two bracketing
    ! reference velocities.  Velocities outside [vref(1), vref(nref)] take
    ! the nearest end-member field.  Assumes vref is ascending -- TODO
    ! confirm against owvref.
    !------------------------------------------------------
    subroutine interpolate_reference_fields(vrow)
        real, intent(in)                :: vrow(:)
        integer                         :: ix, ir
        real                            :: wgt

        do ix = 1, kxm%n
            if (vrow(ix) >= vref(nref)) then
                ws(ix) = gs(ix, nref)
            else if (vrow(ix) <= vref(1)) then
                ws(ix) = gs(ix, 1)
            else
                ws(ix) = gs(ix, 1)   ! fallback if no bracket is found
                do ir = 1, nref - 1
                    ! The strict upper bound implies vref(ir+1) > vref(ir),
                    ! so the weight denominator below is never zero.
                    if (vrow(ix) >= vref(ir) .and. vrow(ix) < vref(ir+1)) then
                        wgt = (vrow(ix) - vref(ir)) / (vref(ir+1) - vref(ir))
                        ws(ix) = (1.0 - wgt) * gs(ix, ir) + wgt * gs(ix, ir+1)
                        exit
                    end if
                end do
            end if
        end do
    end subroutine interpolate_reference_fields

end subroutine owpspi

end module OW_pspi_dd
