module OW_pspi
!> Depth extrapolation of a shot's source and receiver wavefields by phase
!> shift with several reference velocities, a split-step correction and
!> lateral interpolation between the reference wavefields, followed by a
!> cross-correlation imaging condition at lateral shifts.
!>
!> owpspi_init stores per-shot / per-node settings in module variables;
!> owpspi then processes every frequency and accumulates into the caller's
!> image array.

use OWvel
use OW_type
use OW_image
use OW_fftw
use OW_grabvel
use OW_parms

implicit none

public                          :: owpspi_init, owpspi
private                         :: isp, ni, inode, node, ixs

integer                         :: isp, ni, inode, node, ixs
! NOTE(review): wb is set by owpspi_init like the values above but is not in
! the private list, so it stays publicly accessible.
real                            :: wb

contains

!---------------------------------------------------------
!> Store the settings used by later calls to owpspi.
!>   isp_in   - x index at which the source sample is injected
!>   ni_in    - stored in ni (not referenced by owpspi in this file)
!>   inode_in - node id, reported in the progress message
!>   node_in  - stored in node (not referenced by owpspi in this file)
!>   ixs_in   - shot number, reported in the progress message
!>   wb_in    - depth converted to the last skipped depth index
!>              nwb = floor(wb/img%z%d)+1 (name suggests water-bottom
!>              depth -- TODO confirm)
subroutine owpspi_init(isp_in,ni_in,inode_in,node_in,ixs_in,wb_in)

  integer, intent(in) :: isp_in, ni_in, inode_in, node_in, ixs_in
  real,    intent(in) :: wb_in

  isp   = isp_in
  ni    = ni_in
  inode = inode_in
  node  = node_in
  ixs   = ixs_in
  wb    = wb_in

end subroutine owpspi_init

!---------------------------------------------------------
!> Extrapolate one shot's wavefields in depth, frequency by frequency, and
!> add their cross-correlation into GG.
!>
!>   recv(:,iw)     - receiver wavefield (x samples) at frequency index iw
!>   source(1,iw,1) - the only source element used; injected at x index isp
!>   vel(iz,ix)     - velocity for reference velocities and split step
!>   GG(iz,ish,ix)  - image, accumulated in place (never reset here)
!>
!> Depth steps run over iz = nwb+1 .. img%z%n, each of thickness v%z%d.
subroutine owpspi(recv,source,vel,GG)

  complex, intent(in)    :: recv(:,:)
  complex, intent(in)    :: source(:,:,:)
  ! NOTE(review): no intent on vel because a row of it is passed to owvref,
  ! whose argument intents are not visible from this file.
  real                   :: vel(:,:)
  complex, intent(inout) :: GG(:,:,:)

  integer              :: iw, iz, ix, iref, ish, ihso, lo, hi
  integer              :: il, ic, nwb, nvr
  real                 :: minv, maxv, f, g, dv, rnorm
  real,    allocatable :: vref(:)
  complex, allocatable :: ws(:), wr(:), gs(:,:), gr(:,:), tfld(:)

  !--------------------------------------------------
  ! FFT setup: the plans planF/planI are set up by owfftw on tfld
  ! (presumably in place), so every transform is staged through tfld.
  allocate(tfld(rec%h%n))
  tfld = 0
  call owfftw(tfld)
  rnorm = sqrt(real(kxm%n))

  ! Work arrays are allocated once. At each depth only entries 1:nvr of
  ! vref and columns 1:nvr of gs/gr are in use (nvr is nref or nref+1).
  allocate(ws(rec%h%n), wr(rec%h%n))
  allocate(vref(nref+1), gs(rec%h%n,nref+1), gr(rec%h%n,nref+1))

  nwb = floor(wb/img%z%d) + 1

  omega: do iw = 1, rec%w%n

    ws = 0
    ws(isp) = source(1,iw,1)   ! to be modified to handle generalized sources
    wr = recv(:,iw)

    depth: do iz = nwb+1, img%z%n

      !--------------------------------------------------
      ! Reference velocities for this depth row. When the row exceeds vsal,
      ! its maximum is appended as an extra (nref+1)-th reference.
      vref = 0
      nvr  = nref
      if (maxval(vel(iz,:)) > vsal) nvr = nref + 1
      call owvref(nref, vel(iz,:), vref(1:nref), minv, maxv)
      if (nvr > nref) vref(nvr) = maxval(vel(iz,:))

      call to_kx(ws)
      call to_kx(wr)

      vxz: if (minv == maxv .and. maxv == vref(nvr)) then
        ! Laterally constant velocity: a single phase shift, no split step
        ! and no interpolation.
        call phase_shift(ws, minv,  1.0)
        call phase_shift(wr, minv, -1.0)
        call to_x(ws)
        call to_x(wr)
      else
        ! Phase shift with every reference velocity, back to x, then a
        ! split-step correction by the slowness difference 1/vel - 1/vref.
        velref: do iref = 1, nvr
          gs(:,iref) = ws
          gr(:,iref) = wr
          call phase_shift(gs(:,iref), vref(iref),  1.0)
          call phase_shift(gr(:,iref), vref(iref), -1.0)
          call to_x(gs(:,iref))
          call to_x(gr(:,iref))
          gs(:,iref) = gs(:,iref) * exp(cmplx(0, v%z%d*rec%w%all(iw)) &
                                        * (1./vel(iz,:) - 1./vref(iref)))
          gr(:,iref) = gr(:,iref) * exp(cmplx(0,-v%z%d*rec%w%all(iw)) &
                                        * (1./vel(iz,:) - 1./vref(iref)))
        end do velref

        ! Rebuild ws/wr at each x from the reference wavefields. il is the
        ! index of the largest reference velocity below vel(iz,ix), or 0 if
        ! none. Assumes vref(1:nvr) is ascending -- TODO confirm owvref.
        interp: do ix = 1, kxm%n
          il = maxloc(vref(1:nvr), dim=1, mask=vref(1:nvr) < vel(iz,ix))
          if (il >= 1 .and. il <= nvr-2) then
            ! Linear weight across [vref(il), vref(il+1)]: 0 at vref(il),
            ! 1 at vref(il+1), so neighbouring intervals join continuously.
            dv = vref(il+1) - vref(il)
            if (dv > 0.0) then
              f = min(1.0, max(0.0, (vel(iz,ix) - vref(il))/dv))
            else
              f = 1.0
            end if
            g = 1.0 - f
            ws(ix) = g*gs(ix,il) + f*gs(ix,il+1)
            wr(ix) = g*gr(ix,il) + f*gr(ix,il+1)
          else
            ! Below vref(1), inside the top interval, or above vref(nvr):
            ! use a single reference wavefield (index clamped to nvr). The
            ! top interval is deliberately not interpolated -- presumably so
            ! the appended vsal reference is not blended; TODO confirm for
            ! the nvr = nref case.
            ic = min(il + 1, nvr)
            ws(ix) = gs(ix,ic)
            wr(ix) = gr(ix,ic)
          end if
        end do interp
      end if vxz

      !--------------------------------------------------
      ! Imaging condition: correlate source at x-ihso with receiver at
      ! x+ihso. lo:hi is exactly the set of x for which both shifted
      ! indices stay inside 1..kxm%n.
      do ish = 1, img%sh%n
        ihso = floor(((ish-1)*img%sh%d + img%sh%o)/img%xm%d)
        lo = 1 + abs(ihso)
        hi = kxm%n - abs(ihso)
        if (lo <= hi) then
          GG(iz,ish,lo:hi) = GG(iz,ish,lo:hi) &
                           + conjg(ws(lo-ihso:hi-ihso)) * wr(lo+ihso:hi+ihso)
        end if
      end do

    end do depth

    ! Progress report on unit 0 for frequencies 1, 11, 21, ...
    ! Width-less 'a' descriptors print every label in full.
    if (mod(iw,10) == 1) write(0,'(a,i5,3(a,i5))') &
         "DONE FREQUENCY # ", iw, " node = ", inode, " OUT OF ", rec%w%n, " / SHOT # ", ixs

  end do omega

contains

  !> Forward transform of w (x -> kx) through the planned buffer tfld,
  !> scaled by 1/sqrt(kxm%n).
  subroutine to_kx(w)
    complex, intent(inout) :: w(:)
    tfld(:) = w   ! (:) so tfld is never reallocated under the plan
    ! New-array execute: passing tfld makes its modification visible to
    ! the compiler (see FFTW manual, "FFTW Execution in Fortran").
    call sfftw_execute_dft(planF, tfld, tfld)
    w = tfld / rnorm
  end subroutine to_kx

  !> Inverse transform of w (kx -> x); same staging and scaling as to_kx.
  subroutine to_x(w)
    complex, intent(inout) :: w(:)
    tfld(:) = w
    call sfftw_execute_dft(planI, tfld, tfld)
    w = tfld / rnorm
  end subroutine to_x

  !> Apply one depth step (thickness v%z%d) of phase shift with reference
  !> velocity vr to a kx-domain wavefield, at the host's current frequency
  !> index iw. sgn = +1 for the source side, -1 for the receiver side;
  !> components with kz2 <= 0 are damped identically for both.
  subroutine phase_shift(w, vr, sgn)
    complex, intent(inout) :: w(:)
    real,    intent(in)    :: vr, sgn
    integer :: k
    real    :: kz2
    do k = 1, kxm%n
      kz2 = rec%w%all(iw)**2 / vr**2 - kxm%all(k)**2
      if (kz2 > 0) then
        w(k) = w(k) * exp(cmplx(0, sgn*v%z%d*sqrt(kz2)))
      else
        w(k) = w(k) * exp(-sqrt(-kz2)*v%z%d)
      end if
    end do
  end subroutine phase_shift

end subroutine owpspi

end module OW_pspi
