module Pre2D_salt_ssf_mod

  use Pre2D_salt_ssf_types

  use gendown3d_types
  use down_process_types
  use data_types
  use image_types
  use slow_types

  use fft_data_mod

  implicit none

  ! ---- private module state, set up by the init routines below -------------
  ! user parameters, copied in by init_Pre2D_salt_ssf_par
  type (Pre2D_salt_ssf_param_type), private :: Pre2D_salt_ssf_par_loc
  ! 3-D FFT over (n_mx_pad, n_my_pad, n_hx_pad), created by Pre2D_salt_ssf_init
  type (fft_type), private :: fft_3d_id

  ! copy of gen_par%time_mig
  logical, private :: time_mig

  ! constant background slowness (gen_par%slow_const)
  real, private :: slow_const
  ! salt slowness threshold, or -1. when no salt value was given
  real, private :: slow_salt
  ! epsilon(1.), guards the division by k0**2
  real, private :: epsi
  ! cos**2 of gen_par%max_propag_angle
  real, private :: min_prop_cos_sq
  ! sample spacing of the normalized |k|**2 dispersion table
  real, private :: d_k_norm
  ! number of dispersion-table samples, kept as a real index clamp
  real, private :: max_disp_rel
  ! position of the reference slowness inside a slowness bin (1., 0. or .5)
  real, private :: adj_slow

  ! dispersion table: log-amplitude and phase per normalized |k|**2 sample
  real, allocatable, private :: magn(:), phase(:)
  ! wavenumber axes (FFT order, except k_hx when single_hx_dip is set)
  real, allocatable, private :: k_mx(:), k_my(:), k_hx(:)
  ! squared source/receiver wavenumbers; last index is SOU_COMP / REC_COMP
  real, allocatable, private :: k_sr_sq(:,:,:,:)

contains

  subroutine init_Pre2D_salt_ssf_par(Pre2D_salt_ssf_par)
    ! Store a private copy of the user parameters in the module.
    ! Pre2D_salt_ssf_init and the operators read ref_slow, slow_salt,
    ! vel_salt, n_slow_split, n_disp_rel, min_delta_slow and lbf_order from
    ! this copy, so it should be called first.
    !
    ! Pre2D_salt_ssf_par : parameters to copy (not modified)
    type (Pre2D_salt_ssf_param_type), intent(in) :: Pre2D_salt_ssf_par

    Pre2D_salt_ssf_par_loc=Pre2D_salt_ssf_par

  end subroutine init_Pre2D_salt_ssf_par

  function Pre2D_salt_ssf_init(gen_par,data_par,slow_par)  result(info)
    ! Initialise the module state shared by the down/scatter operators from
    ! the parameters stored by init_Pre2D_salt_ssf_par:
    !  * forces slow_par%n_comp_slow to 2 (source and receiver slowness);
    !  * sets adj_slow from ref_slow ('min' -> 1, 'max' -> 0, 'avg' -> .5);
    !  * sets slow_salt to 1.01 times the salt slowness (or 1/vel_salt), or to
    !    -1 when no positive salt value is given;
    !  * creates the 3-D FFT over (n_mx_pad, n_my_pad, n_hx_pad);
    !  * builds the wavenumber axes k_mx, k_my, k_hx and the squared
    !    source/receiver wavenumbers k_sr_sq;
    !  * tabulates the dispersion relation (phase, magn) on n_disp_rel points.
    ! Returns info = info_default with status taken from the FFT init.
    ! Invalid parameters are reported through seperr.
    type (gendown3d_info_type) info
    type (gendown3d_param_type) gen_par
    type (data_param_type) data_par
    type (slow_param_type) slow_par

    integer i_k_mx,i_k_my,i_k_hx,ierr,i_k_norm
    real kmx,kmy,khx,pi2,k_norm_sq,k_norm
    real n_cycles_ksq,k_start,k_decay,sigma_magn_sq
    real delta_k

    info=info_default

    if(slow_par%n_comp_slow /= 2) then
       write(0,*)'Reset n_comp_slow to 2'
       slow_par%n_comp_slow = 2
    end if

    ! adj_slow places the reference slowness inside each slowness bin of
    ! width delta_slow: 1 -> lower edge, 0 -> upper edge, .5 -> centre
    select case (Pre2D_salt_ssf_par_loc%ref_slow(1:3)) 
    case ('min')
      adj_slow=1.
    case ('max')
      adj_slow=0.
    case ('avg')
      adj_slow=.5
    case default
      call seperr('Not supported reference_slowness ')
    end select

    time_mig=gen_par%time_mig
    slow_const=gen_par%slow_const
    pi2=2.*acos(-1.)
    min_prop_cos_sq=cos(gen_par%max_propag_angle*pi2/360.)**2
    epsi=epsilon(1.)

    if (slow_par%slowness) then
       if(Pre2D_salt_ssf_par_loc%slow_salt > 0.) then
          slow_salt=1.01*(Pre2D_salt_ssf_par_loc%slow_salt)
          if(Pre2D_salt_ssf_par_loc%n_slow_split < 2) call seperr('n_slow_split <2')
       else 
          slow_salt=-1.
       end if
    else
       if(Pre2D_salt_ssf_par_loc%vel_salt > 0.) then
          slow_salt=1.01*(1./Pre2D_salt_ssf_par_loc%vel_salt)
          if(Pre2D_salt_ssf_par_loc%n_slow_split < 2) call seperr('n_slow_split <2')
       else 
          slow_salt=-1.
       end if
    end if

    if(gen_par%debug1) write(0,*)'Init slow_salt=',slow_salt

    if(gen_par%debug1) write(0,*)'min_prop_cos_sq=',min_prop_cos_sq
    if(gen_par%debug1) write(0,*)'n_mx_pad=',data_par%n_mx_pad
    if(gen_par%debug1) write(0,*)'n_my_pad=',data_par%n_my_pad
    if(gen_par%debug1) write(0,*)'n_hx_pad=',data_par%n_hx_pad

    fft_3d_id=init_fft_3d(data_par%n_mx_pad,data_par%n_my_pad, &
                          data_par%n_hx_pad,fast_init=.false.)
    info%status=fft_3d_id%ierr

    if(gen_par%debug1) write(0,*)'scale=',fft_3d_id%scale

    if(data_par%n_hy_pad /= 1 ) then
       call seperr(" Pre2D: Trying to use Prestack migration with n_hy > 1")
    end if

    ! the table spacing below divides by n_disp_rel-1
    if(Pre2D_salt_ssf_par_loc%n_disp_rel < 2) call seperr('Pre2D: n_disp_rel must be >= 2')

    allocate (phase(Pre2D_salt_ssf_par_loc%n_disp_rel),stat=ierr)  
    if(ierr .ne. 0) call seperr("Problems with allocation of phase ")
    allocate (magn(Pre2D_salt_ssf_par_loc%n_disp_rel),stat=ierr)  
    if(ierr .ne. 0) call seperr("Problems with allocation of magn ")

    allocate (k_mx(data_par%n_mx_pad),stat=ierr)  
    if(ierr .ne. 0) call seperr("Problems with allocation of k_mx ")
    allocate (k_my(data_par%n_my_pad),stat=ierr)  
    if(ierr .ne. 0) call seperr("Problems with allocation of k_my ")
    allocate (k_hx(data_par%n_hx_pad),stat=ierr)  
    if(ierr .ne. 0) call seperr("Problems with allocation of k_hx ")

    allocate (k_sr_sq(data_par%n_mx_pad,data_par%n_my_pad,data_par%n_hx_pad,2), &
    stat=ierr)  
    if(ierr .ne. 0) call seperr("Problems with allocation of k_sr_sq ")

    ! wavenumber axes in FFT order: non-negative half first, then negative half
    do i_k_mx=1,data_par%n_mx_pad
       kmx  = i_k_mx - 1.
       if(kmx <= float(data_par%n_mx_pad/2)) then
          k_mx(i_k_mx) =  data_par%kmx_0+data_par%d_kmx* kmx
       else
          k_mx(i_k_mx) = -data_par%kmx_0+data_par%d_kmx*(kmx - data_par%n_mx_pad)
       end if
    end do

    do i_k_my=1,data_par%n_my_pad
       kmy = i_k_my - 1.
       if(kmy <= float(data_par%n_my_pad/2)) then
          k_my(i_k_my) =  data_par%kmy_0+data_par%d_kmy* kmy
       else
          k_my(i_k_my) = -data_par%kmy_0+data_par%d_kmy*(kmy - data_par%n_my_pad)
       end if
    end do

    if(gen_par%single_hx_dip) then
       ! single offset dip: a plain monotone axis, no wrap-around
       do i_k_hx=1,data_par%n_hx_pad
          khx = i_k_hx - 1.
          k_hx(i_k_hx) =  data_par%khx_0+data_par%d_khx* khx
       end do
    else
       do i_k_hx=1,data_par%n_hx_pad
          khx = i_k_hx - 1.
          if(khx <= float(data_par%n_hx_pad/2)) then
             k_hx(i_k_hx) =  data_par%khx_0+data_par%d_khx* khx
          else
             k_hx(i_k_hx) = -data_par%khx_0+data_par%d_khx*(khx - data_par%n_hx_pad)
          end if
       end do
    end if

    ! k_s = (k_m - k_h)/2 and k_r = (k_m + k_h)/2; only k_mx enters (2-D case)
    do i_k_hx=1,data_par%n_hx_pad
       do i_k_my=1,data_par%n_my_pad
          do i_k_mx=1,data_par%n_mx_pad
             k_sr_sq(i_k_mx,i_k_my,i_k_hx,SOU_COMP)= &
             .25*((k_mx(i_k_mx) - k_hx(i_k_hx))**2)
             k_sr_sq(i_k_mx,i_k_my,i_k_hx,REC_COMP)= &
             .25*((k_mx(i_k_mx) + k_hx(i_k_hx))**2)
          end do
       end do
    end do

    ! Tabulate the dispersion relation against the normalized squared
    ! wavenumber q = (i_k_norm-1)*d_k_norm, q in [0, n_cycles_ksq]:
    !   phase = sqrt(1-q) where 1-q > 0, otherwise -sqrt(q-1);
    !   magn  = 0 where phase > k_start, otherwise the log of a Gaussian
    !           taper, -(phase-k_start)**2/(2*sigma_magn_sq).
    n_cycles_ksq=5.
    delta_k=sqrt(min_prop_cos_sq)
    if(delta_k <= 0.) call seperr('Pre2D: max_propag_angle leaves no taper width')
    k_start=delta_k
    k_decay=-1.*delta_k
    if(k_decay**2 > n_cycles_ksq) call seperr('k_decay^2 > n_cycles_ksq')
    ! Gaussian that has decayed to exp(k_decay) at k_decay:
    ! -(k_decay-k_start)**2/(2*log(exp(k_decay))).  log(exp(x)) is written
    ! as x because in single precision exp(x) rounds to 1 for |x| below
    ! ~3e-8, which made the denominator 0 and silently removed the taper.
    sigma_magn_sq=-(k_decay-k_start)**2/(2.*k_decay)

    d_k_norm=n_cycles_ksq/(Pre2D_salt_ssf_par_loc%n_disp_rel-1)
    do i_k_norm=1,Pre2D_salt_ssf_par_loc%n_disp_rel
       k_norm_sq=1.-(i_k_norm-1)*d_k_norm
       if( k_norm_sq > 0.) then
          k_norm=sqrt(k_norm_sq)
       else
          k_norm=-sqrt(-k_norm_sq)
       end if

       phase(i_k_norm)=k_norm

       if(k_norm > k_start) then
          magn(i_k_norm)=0.
       else
          magn(i_k_norm)=(-((k_norm-k_start)**2)/(2.*sigma_magn_sq))
       end if
    end do
    ! upper clamp for the table index computed in the down operator
    max_disp_rel=float(Pre2D_salt_ssf_par_loc%n_disp_rel)

  end function Pre2D_salt_ssf_init

  function Pre2D_salt_ssf_down_oper(adj,add, &
  n_w_data,w_0_data,d_w_data, &
  z_down,d_z_down,data_par, &
  data_slice,ddata_slice,slow_slice) &
  result(stat)
    ! Propagate the wavefield slices over one depth step d_z_down with a
    ! split-step Fourier scheme that uses n_slow_split_eff reference
    ! slownesses (see Pre2D_salt_slow_prop). The source and the receiver
    ! component are propagated in turn for every frequency.
    !
    ! For each reference slowness slow_0 the field is phase-shifted in the
    ! wavenumber domain with the tabulated dispersion relation. With a
    ! variable slowness it is also corrected in space by
    ! exp(i*(s(x)-slow_0)*dz_w) (not for time migration). Only the samples
    ! mapped to that reference contribute.
    !
    ! adj          : reverses the component order (receiver first) and applies
    !                the space-domain selection/correction before the
    !                wavenumber phase shift instead of after it
    ! add          : not referenced (kept for the operator interface)
    ! n_w_data, w_0_data, d_w_data : frequency axis, w = w_0+(i_w-1)*d_w
    ! z_down       : not referenced (kept for the operator interface)
    ! d_z_down     : depth step
    ! data_par     : padded sizes n_mx_pad, n_my_pad, n_hx_pad
    ! data_slice, ddata_slice (optional) : fields updated in place; only index
    !                1 of dimension 4 is used. On entry and exit they are in the
    !                domain that FFT_BACK maps to space.
    ! slow_slice   (optional) : slowness per component; when absent slow_const
    !                is used and n_slow_split must be 1
    ! Returns stat = 0.
    integer stat,n_w_data
    logical add,adj
    real w_0_data,d_w_data,z_down,d_z_down
    type (data_param_type) data_par
    real, dimension (:,:,:,:,:) :: slow_slice
    complex, dimension (:,:,:,:,:) :: data_slice,ddata_slice
    optional data_slice,ddata_slice,slow_slice

    logical prop_data,prop_ddata,var_slow
    integer ierr,n_mx,n_my,n_hx,i_w
    integer i_slow,n_slow_split_eff,i_surf_comp,surf_comp
    real w_data,slow_0,k0,k0_sq,k0_dz,dz_w,norm_ksq

    integer, allocatable,  dimension (:,:,:,:) :: map_slow
    ! work buffers: index 2 holds the input slice, index 1 the field being
    ! propagated with the current reference slowness
    complex, allocatable,  dimension (:,:,:,:) :: temp_data
    complex, allocatable,  dimension (:,:,:,:) :: temp_ddata

    real,  dimension(Pre2D_salt_ssf_par_loc%n_slow_split)      :: slow_prop
    complex,  dimension (data_par%n_mx_pad,data_par%n_my_pad,data_par%n_hx_pad) :: ph_shift

    n_mx=data_par%n_mx_pad
    n_my=data_par%n_my_pad
    n_hx=data_par%n_hx_pad

    prop_data= present(data_slice)
    prop_ddata= present(ddata_slice)
    var_slow= present(slow_slice)

    if(var_slow) then
       allocate (map_slow(n_mx,n_my,n_hx,2),stat=ierr)
       if(ierr .ne. 0) call seperr("Problems with allocation of map_slow ")
    end if
    if(prop_data) then
       allocate(temp_data(n_mx,n_my,n_hx,2),stat=ierr)
       if(ierr .ne. 0) call seperr("Problems with allocation of temp_data")
    end if
    if(prop_ddata) then
       allocate(temp_ddata(n_mx,n_my,n_hx,2),stat=ierr)
       if(ierr .ne. 0) call seperr("Problems with allocation of temp_ddata")
    end if

    if(var_slow) then
       n_slow_split_eff=Pre2D_salt_slow_prop(n_mx,n_my,n_hx,  &
       slow_slice, slow_prop, map_slow) 
    else
       if(Pre2D_salt_ssf_par_loc%n_slow_split /= 1) call seperr('n_slow_split /= 1')
       n_slow_split_eff= 1 
       slow_prop(1)=slow_const
    end if

    do i_w=1,n_w_data
       w_data=w_0_data+(i_w-1)*d_w_data
       do i_surf_comp=1,2
          ! forward: source then receiver; adjoint: receiver then source
          if(adj .eqv. (i_surf_comp == 1)) then
             surf_comp=REC_COMP
          else
             surf_comp=SOU_COMP
          end if

          ! keep the input in buffer 2; the output slice is re-accumulated
          if(prop_data) then
             temp_data(:,:,:,1) = cmplx(0.,0.)
             temp_data(:,:,:,2) = data_slice(:,:,:,1,i_w) 
             data_slice(:,:,:,1,i_w) = cmplx(0.,0.) 
          end if
          if(prop_ddata) then
             temp_ddata(:,:,:,1) = cmplx(0.,0.)
             temp_ddata(:,:,:,2) = ddata_slice(:,:,:,1,i_w) 
             ddata_slice(:,:,:,1,i_w) = cmplx(0.,0.) 
          end if

          ! the adjoint selects samples in space: transform the input once
          if(var_slow .and. adj) then
             if(prop_data) call fft_3d_data(fft_3d_id,FFT_BACK,n_mx,n_my,n_hx,temp_data(:,:,:,2))
             if(prop_ddata) call fft_3d_data(fft_3d_id,FFT_BACK,n_mx,n_my,n_hx,temp_ddata(:,:,:,2))
          end if

          do i_slow=1,n_slow_split_eff
             slow_0=slow_prop(i_slow)

             k0= w_data*slow_0
             k0_sq= k0*k0
             if(time_mig) then
                k0_dz=.5*w_data*d_z_down
                dz_w=0.   ! no split-step correction for time migration
             else
                k0_dz=k0*d_z_down
                dz_w=d_z_down*w_data
             end if

             ! maps k_sr_sq onto the dispersion-table index; 0 picks entry 1
             if(k0_sq > epsi) then
                norm_ksq=1./(d_k_norm*k0_sq)
             else
                norm_ksq=0.
             end if

             if(var_slow .and. adj) then
                if(.not. time_mig) call set_split_step_shift()
                if(prop_data) call select_and_forward(temp_data)
                if(prop_ddata) call select_and_forward(temp_ddata)
             else
                if(prop_data) temp_data(:,:,:,1) = temp_data(:,:,:,2)
                if(prop_ddata) temp_ddata(:,:,:,1) = temp_ddata(:,:,:,2)
             end if

             call set_wavenumber_shift()
             if(prop_data) temp_data(:,:,:,1) = temp_data(:,:,:,1) * ph_shift
             if(prop_ddata) temp_ddata(:,:,:,1) = temp_ddata(:,:,:,1) * ph_shift

             if(var_slow .and. .not. adj) then
                if(.not. time_mig) call set_split_step_shift()
                if(prop_data) call backward_and_scatter(temp_data,data_slice)
                if(prop_ddata) call backward_and_scatter(temp_ddata,ddata_slice)
             else
                if(prop_data) then
                   data_slice(:,:,:,1,i_w)=data_slice(:,:,:,1,i_w)+temp_data(:,:,:,1)
                end if
                if(prop_ddata) then
                   ddata_slice(:,:,:,1,i_w)=ddata_slice(:,:,:,1,i_w)+temp_ddata(:,:,:,1)
                end if
             end if
          end do

          ! forward variable-slowness output was assembled in space
          if(var_slow .and. .not. adj) then
             if(prop_data) call fft_3d_data(fft_3d_id,FFT_FORW,n_mx,n_my,n_hx,data_slice(:,:,:,1,i_w))
             if(prop_ddata) call fft_3d_data(fft_3d_id,FFT_FORW,n_mx,n_my,n_hx,ddata_slice(:,:,:,1,i_w))
          end if
       end do
    end do

    if(allocated(map_slow)) deallocate (map_slow,stat=ierr)
    if(allocated(temp_data)) deallocate (temp_data,stat=ierr)
    if(allocated(temp_ddata)) deallocate (temp_ddata,stat=ierr)

    stat=0
    return

  contains

    ! ph_shift = exp(i*(s(x)-slow_0)*dz_w) for component surf_comp
    subroutine set_split_step_shift()
      integer i_mx,i_my,i_hx
      real dphase
      do i_hx=1,n_hx
         do i_my=1,n_my
            do i_mx=1,n_mx
               dphase=(slow_slice(i_mx,i_my,i_hx,1,surf_comp)-slow_0)*dz_w
               ph_shift(i_mx,i_my,i_hx)= cmplx(cos(dphase),sin(dphase))
            end do
         end do
      end do
    end subroutine set_split_step_shift

    ! ph_shift = exp(|k0_dz|*magn + i*k0_dz*phase), looked up in the
    ! dispersion table at the normalized k_sr_sq of component surf_comp
    subroutine set_wavenumber_shift()
      integer i_k_mx,i_k_my,i_k_hx,i_ksq_norm
      real rphase,rmagn
      do i_k_hx=1,n_hx
         do i_k_my=1,n_my
            do i_k_mx=1,n_mx
               i_ksq_norm=nint(min(k_sr_sq(i_k_mx,i_k_my,i_k_hx,surf_comp)* &
               norm_ksq+1,max_disp_rel))
               rphase=phase(i_ksq_norm)*k0_dz
               rmagn=abs(k0_dz)*magn(i_ksq_norm)
               ph_shift(i_k_mx,i_k_my,i_k_hx)=exp(cmplx(rmagn,rphase))
            end do
         end do
      end do
    end subroutine set_wavenumber_shift

    ! Adjoint path: keep only the space-domain samples owned by i_slow (all
    ! others must be zero, otherwise the previous reference's propagated
    ! field would be propagated and summed again), apply the split-step
    ! correction and transform with FFT_FORW.
    subroutine select_and_forward(buf)
      complex, dimension (n_mx,n_my,n_hx,2), intent(inout) :: buf
      where(map_slow(:,:,:,surf_comp) == i_slow)
         buf(:,:,:,1)=buf(:,:,:,2)
      elsewhere
         buf(:,:,:,1)=cmplx(0.,0.)
      end where
      if(.not. time_mig) buf(:,:,:,1)=buf(:,:,:,1)*ph_shift
      call fft_3d_data(fft_3d_id,FFT_FORW,n_mx,n_my,n_hx,buf(:,:,:,1))
    end subroutine select_and_forward

    ! Forward path: transform to space with FFT_BACK, apply the split-step
    ! correction and copy the samples owned by i_slow into slice(:,:,:,1,i_w).
    subroutine backward_and_scatter(buf,slice)
      complex, dimension (n_mx,n_my,n_hx,2), intent(inout) :: buf
      complex, dimension (:,:,:,:,:), intent(inout) :: slice
      call fft_3d_data(fft_3d_id,FFT_BACK,n_mx,n_my,n_hx,buf(:,:,:,1))
      if(.not. time_mig) buf(:,:,:,1)=buf(:,:,:,1)*ph_shift
      where(map_slow(:,:,:,surf_comp) == i_slow) slice(:,:,:,1,i_w)=buf(:,:,:,1)
    end subroutine backward_and_scatter

  end function Pre2D_salt_ssf_down_oper

  function Pre2D_salt_ssf_scatter_oper(adj,add, &
  n_w_data,w_0_data,d_w_data, &
  z_down,d_z_down,data_par, &
  data_slice,ddata_slice,dslow_slice,slow_slice) &
  result(stat)
    ! First-order scattering over one depth step, linear in the slowness
    ! perturbation, for the source and receiver components.
    !
    ! Forward (.not. adj): for each component,
    !   ddata_slice += FFT_FORW( FFT_BACK(w*data_slice) * dslow * i*dz*w_data ),
    ! where the optional weight w (lbf_order > 1) is 1+.5r+.375r**2 with
    ! r = k_sr_sq/k0**2 inside the propagating region and 1 outside.
    ! Adjoint (adj): for each component,
    !   dslow_slice += i*dz*w_data * conjg(FFT_BACK(w*data_slice)) * FFT_BACK(ddata_slice).
    !
    ! add, z_down : not referenced (kept for the operator interface)
    ! slow_slice  : optional; only one reference slowness is supported
    ! Time migration is rejected through seperr. Returns stat = 0.
    integer stat,n_w_data
    logical add,adj
    real w_0_data,d_w_data,z_down,d_z_down
    type (data_param_type) data_par
    real, dimension (:,:,:,:,:) :: slow_slice
    complex, dimension (:,:,:,:,:) :: dslow_slice
    complex, dimension (:,:,:,:,:) :: data_slice,ddata_slice
    optional slow_slice

    logical var_slow
    integer i_k_mx,i_k_my,i_k_hx,ierr,n_mx,n_my,n_hx,i_mx,i_my,i_hx,i_w
    integer i_surf_comp,surf_comp,i_slow,n_slow_split_eff
    real slow_0,k0,k0_sq,w_data,dz_w,sigma,rat
    complex im_dz_w,ddscat

    integer, allocatable,  dimension (:,:,:,:) :: map_slow

    real,  dimension(Pre2D_salt_ssf_par_loc%n_slow_split)      :: slow_prop
    complex, dimension (data_par%n_mx_pad,data_par%n_my_pad,data_par%n_hx_pad) :: data_scat,ddata_scat

    n_mx=data_par%n_mx_pad
    n_my=data_par%n_my_pad
    n_hx=data_par%n_hx_pad

    var_slow= present(slow_slice)

    if(var_slow) then
       allocate (map_slow(n_mx,n_my,n_hx,2),stat=ierr)
       if(ierr .ne. 0) call seperr("Problems with allocation of map_slow ")
       n_slow_split_eff=Pre2D_salt_slow_prop(n_mx,n_my,n_hx,  &
       slow_slice, slow_prop, map_slow) 
    else
       if(Pre2D_salt_ssf_par_loc%n_slow_split /= 1) call seperr('n_slow_split /= 1')
       n_slow_split_eff= 1 
       slow_prop(1)=slow_const
    end if

    if(n_slow_split_eff /= 1) call seperr('Pre2D_salt_ssf_scatter_oper cannot handle n_slow_split_eff >1')

    do i_w=1,n_w_data
       w_data=w_0_data+(i_w-1)*d_w_data

       if(adj) then
          ! residual in space; it does not depend on the component, so it is
          ! transformed once per frequency
          ddata_scat=ddata_slice(:,:,:,1,i_w)
          call fft_3d_data(fft_3d_id,FFT_BACK,n_mx,n_my,n_hx,ddata_scat)
       end if

       do i_surf_comp=1,2
          ! forward: source then receiver; adjoint: receiver then source
          if(adj .eqv. (i_surf_comp == 1)) then
             surf_comp=REC_COMP
          else
             surf_comp=SOU_COMP
          end if

          do i_slow=1,n_slow_split_eff
             slow_0=slow_prop(i_slow)
             k0= w_data*slow_0
             k0_sq= k0*k0

             if(time_mig) call seperr('Scattering for time mig not supported yet!')
             dz_w=d_z_down*w_data
             im_dz_w=cmplx(0.,dz_w)

             ! weight the background wavefield in the wavenumber domain
             data_scat=data_slice(:,:,:,1,i_w)
             if(Pre2D_salt_ssf_par_loc%lbf_order > 1) then
                do i_k_hx=1,n_hx
                   do i_k_my=1,n_my
                      do i_k_mx=1,n_mx
                         if((k0_sq - k_sr_sq(i_k_mx,i_k_my,i_k_hx,surf_comp) ) > 0.) then
                            rat=(k_sr_sq(i_k_mx,i_k_my,i_k_hx,surf_comp)/k0_sq)
                            sigma=1+.5*rat + .375*rat*rat
                         else
                            sigma=1
                         end if
                         data_scat(i_k_mx,i_k_my,i_k_hx)=data_scat(i_k_mx,i_k_my,i_k_hx)*sigma
                      end do
                   end do
                end do
             end if

             ! background wavefield in space
             call fft_3d_data(fft_3d_id,FFT_BACK,n_mx,n_my,n_hx,data_scat)

             do i_hx=1,n_hx
                do i_my=1,n_my
                   do i_mx=1,n_mx
                      if(adj) then
                         ! NOTE(review): the adjoint of multiplying by
                         ! u*im_dz_w would use conjg(im_dz_w); the sign is left
                         ! as in the original pending a dot-product test.
                         dslow_slice(i_mx,i_my,i_hx,1,surf_comp) =  &
                         dslow_slice(i_mx,i_my,i_hx,1,surf_comp) + &
                         (im_dz_w*conjg(data_scat(i_mx,i_my,i_hx)))*ddata_scat(i_mx,i_my,i_hx)
                      else
                         ! scattered wavefield
                         ddscat=dslow_slice(i_mx,i_my,i_hx,1,surf_comp)*im_dz_w
                         if(abs(ddscat) /= 0.) then
                            data_scat(i_mx,i_my,i_hx) = &
                            data_scat(i_mx,i_my,i_hx) * ddscat
                         else
                            data_scat(i_mx,i_my,i_hx) = cmplx(0.,0.)
                         end if
                      end if
                   end do
                end do
             end do

             if(.not. adj) then
                ! back to the wavenumber domain and accumulate
                call fft_3d_data(fft_3d_id,FFT_FORW,n_mx,n_my,n_hx,data_scat)
                ddata_slice(:,:,:,1,i_w) = ddata_slice(:,:,:,1,i_w) + data_scat
             end if
          end do
       end do
    end do

    if(allocated(map_slow)) deallocate (map_slow,stat=ierr)

    stat=0
    return	
  end function Pre2D_salt_ssf_scatter_oper

  function Pre2D_salt_ssf_clean()  result(stat)
    ! Release the module arrays allocated by Pre2D_salt_ssf_init.
    ! Arrays that are not allocated are skipped. Returns the number of
    ! deallocations that reported an error (0 on success).
    ! The status is only inspected after a deallocate actually ran; the
    ! previous version tested ierr even when the deallocate was skipped, so an
    ! uninitialized or stale value could be counted as a failure.
    integer stat,ierr
    stat=0
    if(allocated(phase)) then
       deallocate (phase,stat=ierr)
       if(ierr /=0) stat=stat+1
    end if
    if(allocated(magn)) then
       deallocate (magn,stat=ierr)
       if(ierr /=0) stat=stat+1
    end if
    if(allocated(k_mx)) then
       deallocate (k_mx,stat=ierr)
       if(ierr /=0) stat=stat+1
    end if
    if(allocated(k_my)) then
       deallocate (k_my,stat=ierr)
       if(ierr /=0) stat=stat+1
    end if
    if(allocated(k_hx)) then
       deallocate (k_hx,stat=ierr)
       if(ierr /=0) stat=stat+1
    end if
    if(allocated(k_sr_sq)) then
       deallocate (k_sr_sq,stat=ierr)
       if(ierr /=0) stat=stat+1
    end if

    return	
  end function Pre2D_salt_ssf_clean

  function Pre2D_salt_slow_prop(n_mx,n_my,n_hx,  &
  slow_slice, slow_prop, map_slow) &
  result (n_slow_split_eff)
    ! Choose the reference slownesses of the split-step scheme and assign each
    ! (mx,my,hx) sample of both surface components to one of them.
    !
    ! When slow_salt > 0, samples with slowness <= slow_salt are salt. They use
    ! slow_salt as reference, stored in the last slot. The remaining range
    ! [min_slow, max_slow] is cut into at most n_slow_split bins
    ! (n_slow_split-1 when salt is present). Bins are merged when their width
    ! would fall below min_delta_slow*mean_slow. Bin i uses the reference
    ! min_slow+(i-adj_slow)*delta_slow. A single bin uses slow_const instead
    ! when slow_const > 0.
    !
    ! n_mx,n_my,n_hx : loop extents over map_slow / slow_slice
    ! slow_slice     : slowness; index 1 of dim 4, SOU_COMP/REC_COMP in dim 5
    ! slow_prop      : reference slownesses 1..n_slow_split_eff (output)
    ! map_slow       : reference index per sample and component (output)
    ! Returns the number of reference slownesses in use.
    integer n_slow_split_eff
    integer n_mx,n_my,n_hx
    integer,  dimension (:,:,:,:), intent(inout) :: map_slow
    real, dimension (:,:,:,:,:), intent(in)      :: slow_slice
    real,  dimension(:), intent(inout)           :: slow_prop

    logical found_salt
    integer i_mx,i_my,i_hx
    integer n_slow_split_no_salt,i_slow
    real min_slow,max_slow,range_slow,delta_slow,mean_slow,inv_delta_slow

    max_slow=maxval(slow_slice)

    if(slow_salt > 0.) then
       found_salt=any(slow_slice <= slow_salt)
       if(any(slow_slice > slow_salt)) then
          min_slow=minval(slow_slice,mask=(slow_slice > slow_salt))
       else
          ! Every sample is salt: minval over an empty mask returns huge(),
          ! which would poison the bin arithmetic. Collapse the range instead.
          min_slow=max_slow
       end if

       if(found_salt) then
          n_slow_split_no_salt=Pre2D_salt_ssf_par_loc%n_slow_split-1
       else
          n_slow_split_no_salt=Pre2D_salt_ssf_par_loc%n_slow_split
       end if
    else
       found_salt=.false.
       min_slow=minval(slow_slice)
       n_slow_split_no_salt=Pre2D_salt_ssf_par_loc%n_slow_split
    end if

    range_slow=max_slow-min_slow
    mean_slow=.5*(max_slow+min_slow)
    if(n_slow_split_no_salt > 1) then
       delta_slow=range_slow/(n_slow_split_no_salt-1)
    else
       delta_slow=range_slow/n_slow_split_no_salt
    end if
    if(delta_slow < Pre2D_salt_ssf_par_loc%min_delta_slow*mean_slow) then
       ! bins too narrow: use fewer of them
       if(range_slow < Pre2D_salt_ssf_par_loc%min_delta_slow*mean_slow) then
          n_slow_split_no_salt= 1 
          delta_slow=range_slow
          inv_delta_slow=0.
       else
          n_slow_split_no_salt= & 
          min(n_slow_split_no_salt, &
          max(1,int(range_slow/(Pre2D_salt_ssf_par_loc%min_delta_slow*mean_slow))))
          if(n_slow_split_no_salt > 1) then
             delta_slow=range_slow/(n_slow_split_no_salt-1)
          else
             delta_slow=range_slow/n_slow_split_no_salt
          end if
          inv_delta_slow=1./delta_slow
       end if
    else if(delta_slow > 0.) then
       inv_delta_slow=1./delta_slow
    else
       ! zero-width bins with a non-positive merge threshold: a single bin
       ! avoids the division by zero
       n_slow_split_no_salt= 1
       delta_slow=range_slow
       inv_delta_slow=0.
    end if

    ! reference slownesses outside salt
    if(n_slow_split_no_salt > 1) then
       do i_slow=1,n_slow_split_no_salt
          slow_prop(i_slow)=min_slow+(i_slow-adj_slow)*delta_slow
       end do
    else
       if(slow_const > 0.) then  
          slow_prop(1)=slow_const
       else
          slow_prop(1)=min_slow+(1.-adj_slow)*delta_slow
       end if
    end if

    do i_hx=1,n_hx
       do i_my=1,n_my
          do i_mx=1,n_mx
             map_slow(i_mx,i_my,i_hx,SOU_COMP)= &
             min(n_slow_split_no_salt, &
             int((slow_slice(i_mx,i_my,i_hx,1,SOU_COMP)-min_slow)*inv_delta_slow)+1)
             map_slow(i_mx,i_my,i_hx,REC_COMP)= &
             min(n_slow_split_no_salt, &
             int((slow_slice(i_mx,i_my,i_hx,1,REC_COMP)-min_slow)*inv_delta_slow)+1)
          end do
       end do
    end do

    if(found_salt) then
       n_slow_split_eff= n_slow_split_no_salt + 1
       slow_prop(n_slow_split_eff)=slow_salt
       ! Same test as found_salt (<=). Samples equal to slow_salt lie below
       ! min_slow, so with a strict < they would get an index <= 0 and be
       ! dropped by every reference slowness.
       do i_hx=1,n_hx
          do i_my=1,n_my
             do i_mx=1,n_mx
                if(slow_slice(i_mx,i_my,i_hx,1,SOU_COMP) <= slow_salt) then
                   map_slow(i_mx,i_my,i_hx,SOU_COMP)= n_slow_split_eff
                end if
                if(slow_slice(i_mx,i_my,i_hx,1,REC_COMP) <= slow_salt) then
                   map_slow(i_mx,i_my,i_hx,REC_COMP)= n_slow_split_eff
                end if
             end do
          end do
       end do
    else
       n_slow_split_eff= n_slow_split_no_salt 
    end if
  end function Pre2D_salt_slow_prop
end module Pre2D_salt_ssf_mod












