# Tomography operator between a slowness model (slow) and ray traveltimes
# (time_ray), written for the SEP ratfor90 preprocessor.
#   tomo_oper_init  - caches model/ray geometry and points the module's ray
#                     arrays at the data held in a ray object
#   tomo_oper       - public entry point taking flat 1-D arrays
#   tomo_oper_apply - forward (adj=.false.) / adjoint (adj=.true.) application;
#                     each ray is accumulated through lint2_scale
module tomo_oper_mod{
use process_types
use model_types
use ray_types
use lint2_scale

# Parameters copied from the tomo/model/ray objects in tomo_oper_init.
integer, private,save    :: tau_domain,approx,new
integer, private,save    :: ndim_mod,nv_mod,nx_mod,ny_mod,n1_up,n2_up
real,  private,save      :: dv_mod,dx_mod,dy_mod,v0_mod,x0_mod,y0_mod,o1_up,d1_up
integer, private,save    :: ndim_ray,nt_ray,n_ray,nref_ray
logical ,private, save   :: test_it
real, private,save       :: t0_ray,d2_up,o2_up
real, private,save       :: dt_ray

# All pointers start disassociated so associated() is well defined before
# the first call to tomo_oper_init.
real, pointer, dimension (:,:,:,:), private :: traj_ray=>null()
real, pointer, dimension (:,:,:,:), private :: rays_up=>null()
real, pointer, dimension (:,:,:,:), private :: mod_up=>null()
real, pointer, dimension (:,:,:,:), private :: rpar_ray=>null()
real, pointer, dimension (:,:,:,:), private :: mod_ray=>null()
real, pointer, dimension (:,:), private :: tref_ray=>null()
real, pointer, dimension (:,:), private :: ttot_ray=>null()
real, pointer, dimension (:,:), private :: xy1=>null(),xy2=>null()
real, pointer, dimension (:,:,:), private :: sigma0_op=>null(),slow0_op=>null()

# Work arrays for lint2_scale, (re)allocated in tomo_oper_init.
real,pointer,dimension (:) :: scale1=>null(),scale2=>null()
real,pointer,dimension (:) :: hold1=>null(),hold2=>null()
#integer,pointer,dimension(:,:), private  :: zero_offset

contains

# Cache model/ray geometry and point the module's ray arrays at `ray`.
# May be called more than once: work arrays from a previous call are
# released before being reallocated.
#   param       - supplies tau_domain, approx, new
#   model_param - model grid (n/o/d for the v, x, y axes)
#   ray         - ray trajectories, times, velocities and up-going rays
#   sigma0      - stored by pointer only when param%tau_domain == 1
#   slow0       - stored by pointer in slow0_op
#   test        - stored in test_it (not read in this module)
subroutine tomo_oper_init(param,model_param,ray,sigma0,slow0,test){
type (tomo_param_type) param
type (model_param_type) model_param
type (ray_obj_type) ray
real, pointer, dimension (:,:,:) :: sigma0,slow0
#integer, pointer, dimension (:,:) :: zero_o
logical :: test

test_it=test

tau_domain=param%tau_domain
approx=param%approx
new=param%new

ndim_mod=model_param%ndim_mod
nv_mod=model_param%nv_mod
nx_mod=model_param%nx_mod
ny_mod=model_param%ny_mod
v0_mod=model_param%v0_mod
x0_mod=model_param%x0_mod
y0_mod=model_param%y0_mod
dv_mod=model_param%dv_mod
dx_mod=model_param%dx_mod
dy_mod=model_param%dy_mod

ndim_ray=ray%ndim_ray
nt_ray=ray%nt_ray
n_ray=ray%n_ray
nref_ray=ray%nref_ray
t0_ray=ray%t0_ray
dt_ray=ray%dt_ray
o1_up=ray%o1_up
d1_up=ray%d1_up
n1_up=ray%n1_up
n2_up=ray%n2_up
o2_up=ray%o2_up
d2_up=ray%d2_up

# Release work arrays from a previous initialization (the sizes may have
# changed); allocating an already-allocated pointer would be an error.
if (associated(xy1)) deallocate(xy1)
if (associated(xy2)) deallocate(xy2)
if (associated(scale1)) deallocate(scale1)
if (associated(scale2)) deallocate(scale2)
if (associated(hold1)) deallocate(hold1)
if (associated(hold2)) deallocate(hold2)

allocate(xy2(n1_up,2),xy1(nt_ray,2))
allocate(scale2(n1_up),scale1(nt_ray))
allocate(hold2(n1_up),hold1(nt_ray))

if (.not.associated (traj_ray, ray%traj_ray)) traj_ray => ray%traj_ray
if (.not.associated (rpar_ray, ray%rpar_ray)) rpar_ray => ray%rpar_ray
if (.not.associated (mod_ray, ray%mod_ray)) mod_ray => ray%mod_ray
if (.not.associated (ttot_ray, ray%ttot_ray)) ttot_ray => ray%ttot_ray
if (.not.associated (tref_ray, ray%tref_ray)) tref_ray => ray%tref_ray
if (.not.associated (rays_up, ray%rays_up)) rays_up => ray%rays_up
if (.not.associated (mod_up, ray%mod_up)) mod_up => ray%mod_up

if (.not.associated (slow0_op, slow0)) {
  slow0_op => slow0
}

if (tau_domain == 1) {
  if (.not.associated (sigma0_op, sigma0)) sigma0_op => sigma0
}
}


# Public entry point. Forwards to tomo_oper_apply, which reinterprets the
# flat arrays as slow(nv_mod,nx_mod,ny_mod) and time_ray(n_ray,nref_ray).
# Returns the status from tomo_oper_apply.
function tomo_oper (adj, add, slow, time_ray) result (stat){
logical, intent(in) :: adj,add
integer :: stat
real,  dimension (:) :: slow
real,  dimension (:) :: time_ray

stat=tomo_oper_apply (adj, add, slow, time_ray)
}

# Apply the operator over every reflector i_ref and ray i_ray.
#   adj=.false. : time_ray(i_ray,i_ref) += sum of lint2_scale output from slow
#   adj=.true.  : time_ray(i_ray,i_ref) is spread into slow by lint2_scale
#   add=.false. : the output (slow or time_ray) is zeroed first
# Each ray has two legs:
#   leg 1: traj_ray, weighted per segment by (segment time)*(mean mod_ray)
#   leg 2: the up-going ray rays_up(:,:,i_up,i_ref), where i_up is chosen
#          from traj_ray(1,...) at the reflection sample; weighted by
#          -2*d1_up*(mean mod_up)
# The exact action of lint2_scale_init/lint2_scale_lop2 is defined in the
# lint2_scale module (not visible here).
# Returns stat=0.
function tomo_oper_apply (adj, add, slow, time_ray) result (stat){
logical, intent(in) :: adj,add
integer :: stat
real,  dimension (nv_mod,nx_mod,ny_mod) :: slow,temp_a,temp_b
real,  dimension (n_ray,nref_ray) :: time_ray

integer :: i_ref,i_ray,it,it_end,nt_ref,nt_inc,i_up,one
real :: dt_ray_ref,vel_ray,idt_ray
real :: delta_ray(3)

# zero the output if we are not accumulating into it
if (.not. add) {
  if (adj) slow=0.
  else time_ray=0.
}

# Shift the first axis of slow up by `one` samples; lint2_scale_init is
# called with origin 0. for that axis, so presumably this lines index 1 up
# with value 0. temp_a keeps the unshifted array for the restore at the end.
# NOTE(review): assumes v0_mod <= 0; a positive v0_mod makes `one`
# negative and these slices go out of bounds -- confirm.
one=(0.-v0_mod)/dv_mod
temp_a=slow
slow=0.
slow(1:nv_mod-2*one,:,:)=temp_a(one+1:nv_mod-one,:,:)

idt_ray=1./dt_ray

do i_ref=1,nref_ray{
  # loop over the reflectors
  do i_ray=1,n_ray{
    # loop over the rays
    # sample index of the reflection, and number of samples after it
    nt_inc=((tref_ray(i_ray,i_ref)-t0_ray)*idt_ray)+1
    nt_ref=((ttot_ray(i_ray,i_ref)-tref_ray(i_ray,i_ref))*idt_ray)+1

    # ---- leg 1: the traced trajectory traj_ray ----
    scale1=0
    xy1(:,1)=traj_ray(2,:,i_ray,i_ref)
    xy1(:,2)=traj_ray(1,:,i_ray,i_ref)

    # xy1/scale1 hold nt_ray samples: never step past the trajectory end
    it_end=min(nt_inc+nt_ref+2,nt_ray)
    do it=2,it_end{
      delta_ray=traj_ray(:,it,i_ray,i_ref)-traj_ray(:,it-1,i_ray,i_ref)
      # no movement: no contribution (scale1(it) stays 0)
      if(all(abs(delta_ray)<.00001)) next

      vel_ray=.5*(mod_ray(1,it,i_ray,i_ref)+mod_ray(1,it-1,i_ray,i_ref))

      if((it /= nt_inc+1) .and. (it /= nt_inc+2)) {
        # regular segment away from the reflection
        dt_ray_ref=dt_ray
      }
      else {
        # partial segments on either side of the reflection
        # NOTE(review): nt_inc is counted from t0_ray, yet both expressions
        # add t0_ray rather than subtract it; they agree with that indexing
        # only when t0_ray == 0 -- confirm.
        if(it == nt_inc+1) {
          dt_ray_ref=tref_ray(i_ray,i_ref)-(nt_inc-1)*dt_ray+t0_ray
        }
        else {
          dt_ray_ref=ttot_ray(i_ray,i_ref)-tref_ray(i_ray,i_ref) &
                     -(nt_ref-1)*dt_ray+t0_ray
        }
      }
      scale1(it)=dt_ray_ref*vel_ray   # length of the ray segment
    }

    call lint2_scale_init(nv_mod,nx_mod,0.,dv_mod,x0_mod,dx_mod,xy1,scale1)
    if(adj) hold1=time_ray(i_ray,i_ref)
    else hold1=0
    call lint2_scale_lop2(adj,slow,hold1)
    if(.not.adj) time_ray(i_ray,i_ref)+=sum(hold1)

    # ---- leg 2: up-going ray picked at the reflection location ----
    i_up=(traj_ray(1,nt_inc+1,i_ray,i_ref)-o2_up)/d2_up+1.
    if(i_up<1) i_up=1
    if(i_up>n2_up) i_up=n2_up

    scale2=0
    xy2(:,1)=rays_up(2,:,i_up,i_ref)
    xy2(:,2)=rays_up(1,:,i_up,i_ref)

    do it=2,n1_up{
      # stop once coordinate 2 of the up-going ray becomes negative
      if(rays_up(2,it,i_up,i_ref)< 0.) break
      vel_ray=.5*(mod_up(1,it,i_up,i_ref)+mod_up(1,it-1,i_up,i_ref))
      scale2(it)=-d1_up*vel_ray*2
    }

    call lint2_scale_init(nv_mod,nx_mod,0.,dv_mod,x0_mod,dx_mod,xy2,scale2)
    if(adj) hold2=time_ray(i_ray,i_ref)
    else hold2=0
    call lint2_scale_lop2(adj,slow,hold2)
    if(.not.adj) time_ray(i_ray,i_ref)+=sum(hold2)
  }
}

# Undo the first-axis shift. Samples outside the shifted window keep
# their values from temp_a (the input as it was on entry).
temp_b=slow
slow=temp_a
slow(one+1:nv_mod-one,:,:)=temp_b(1:nv_mod-2*one,:,:)

stat=0
}


# Split a segment from loc1 to loc2 across the (at most three) cells of the
# first two axes of slow that it touches (third axis fixed at 1).
#   adj=.false. : time_ray += amplitude * weighted slow values
#   adj=.true.  : touched slow cells += amplitude * weight * time_ray
# loc(2) maps to axis 1 (origin 0, step dv_mod); loc(1) maps to axis 2
# (origin x0_mod, step dx_mod).
# NOTE(review): cell indices come from integer truncation (b=tempr1), which
# rounds toward zero, not down, for negative coordinates -- confirm.
# NOTE(review): when e > b, tempr is the fraction of the segment inside cell
# b, yet cell b receives 1-tempr. When e < b, tempr is the fraction inside
# cell e, and cell e receives tempr. The increasing-direction weights look
# swapped -- confirm before relying on this routine (not called in this
# module).
subroutine interpolate_point(adj,slow,time_ray,loc1,loc2,amplitude){
logical :: adj
real,  dimension (nv_mod,nx_mod,ny_mod) :: slow
real :: tempr1(2),tempr2(2),tempa,tempb,loc1(3),loc2(3),time_ray,amplitude,tempr
integer :: b(2),e(2)

# fractional 1-based grid coordinates of the two end points
tempr1=(/(loc1(2))/dv_mod+1., (loc1(1)-x0_mod)/dx_mod+1./)
tempr2=(/(loc2(2))/dv_mod+1., (loc2(1)-x0_mod)/dx_mod+1./)

b=tempr1;e=tempr2

if(b(1)==e(1)){
  if(b(2)==e(2)){
    # both end points in the same cell
    if(adj) slow(b(1),b(2),1)+=time_ray*amplitude
    else time_ray+=amplitude*slow(b(1),b(2),1)
  }
  else{
    # segment crosses a boundary along axis 2 only
    if(e(2)>b(2)) {
      tempr=(ceiling(tempr1(2))-tempr1(2))/(tempr2(2)-tempr1(2))
    }
    else {
      tempr=(ceiling(tempr2(2))-tempr2(2))/(tempr1(2)-tempr2(2))
    }

    if(adj){
      slow(b(1),b(2),1)+=amplitude*time_ray*(1.-tempr)
      slow(e(1),e(2),1)+=amplitude*time_ray*tempr
    }
    else{
      time_ray+=slow(b(1),b(2),1)*amplitude*(1.-tempr)
      time_ray+=slow(e(1),e(2),1)*amplitude*tempr
    }
  }
}
else{
  if(b(2)==e(2)){
    # segment crosses a boundary along axis 1 only
    if(e(1)>b(1)) {
      tempr=(ceiling(tempr1(1))-tempr1(1))/(tempr2(1)-tempr1(1))
    }
    else {
      tempr=(ceiling(tempr2(1))-tempr2(1))/(tempr1(1)-tempr2(1))
    }

    if(adj){
      slow(b(1),b(2),1)+=amplitude*(1.-tempr)*time_ray
      slow(e(1),e(2),1)+=amplitude*tempr*time_ray
    }
    else{
      time_ray+=slow(b(1),b(2),1)*amplitude*(1.-tempr)
      time_ray+=slow(e(1),e(2),1)*amplitude*tempr
    }
  }
  else{
    # segment crosses boundaries along both axes: three cells
    if(e(1)>b(1)) {
      tempa=(ceiling(tempr1(1))-tempr1(1))/(tempr2(1)-tempr1(1))
    }
    else {
      tempa=(ceiling(tempr2(1))-tempr2(1))/(tempr1(1)-tempr2(1))
    }

    if(e(2)>b(2)) {
      tempb=(ceiling(tempr1(2))-tempr1(2))/(tempr2(2)-tempr1(2))
    }
    else {
      tempb=(ceiling(tempr2(2))-tempr2(2))/(tempr1(2)-tempr2(2))
    }

    if(tempa>tempb){
      # the axis-1 crossing comes first: path b -> (e1,b2) -> e
      if(adj) {
        slow(b(1),b(2),1)+=amplitude*(1.-tempa)*time_ray
        slow(e(1),b(2),1)+=amplitude*(tempa-tempb)*time_ray
        slow(e(1),e(2),1)+=amplitude*tempb*time_ray
      }
      else{
        time_ray+=slow(b(1),b(2),1)*amplitude*(1.-tempa)
        time_ray+=slow(e(1),b(2),1)*amplitude*(tempa-tempb)
        time_ray+=slow(e(1),e(2),1)*amplitude*tempb
      }
    }
    else{
      # the axis-2 crossing comes first: path b -> (b1,e2) -> e
      if(adj) {
        slow(b(1),b(2),1)+=amplitude*(1.-tempb)*time_ray
        slow(b(1),e(2),1)+=amplitude*(tempb-tempa)*time_ray
        slow(e(1),e(2),1)+=amplitude*tempa*time_ray
      }
      else{
        time_ray+=slow(b(1),b(2),1)*amplitude*(1.-tempb)
        time_ray+=slow(b(1),e(2),1)*amplitude*(tempb-tempa)
        time_ray+=slow(e(1),e(2),1)*amplitude*tempa
      }
    }
  }
}
}

}
