/*
 *
	computeModPk.cpp
	
	by Mark C. Wyman, Perimeter Institute for Theoretical Physics.
	
	v1, released March 8, 2009.
 
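	Computes the power spectrum at z=0 from the linear power spectra at two closely spaced early redshifts in the modified gravity model under study.
	As currently hard-coded in main(), these are z=49.1 (read from "z49p1.txt") and z=49.0 (read from "z49pk.txt").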

	Requires two text files containing your theory's linear power spectra at those two redshifts.  Requires you to 
	hard code in the rc and alpha values of the theory you wish to model. "She may not look like much, but she's got it where it counts."
	
	The nonlinear correction applied is that for
	the algorithm of R.E. Smith et al, [The Virgo Consortium Collaboration],
   ``Stable clustering, the halo model and nonlinear cosmological power spectra,''
	Mon. Not. Roy. Astron. Soc.  {\bf 341}, 1311 (2003)
	 [arXiv:astro-ph/0207664]. 
  
	 The version of the halofitting algorithm which we began with is that of Martin Kilbinger,
	 http://www.astro.uni-bonn.de/~kilbinge/cosmology/smith2/readme ;
	 
	 If used, please cite J. Khoury and M. Wyman, arXiv:0903.1292
	
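	 It outputs a file named "Pknow.txt" with three columns: 
	 
	 the wavenumber, the linear power spectrum, and the nonlinearly corrected power spectrum.
	 
	 It also writes a diagnostic file named "growthrate.txt". Every tenth integration step it records the scale factor,
	 the ratio (d delta/dN)/delta for four sample modes, and the quantity H0*h*sqrt(OmM*a^-3)/Hubble(a).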
	 
	 In its present form, the code requires number of lines in the linear power spectrum to be 
	 hand coded in as Nmodes. Note that the iterative bisection necessary to locate 
	 the nonlinear scale requires a goodly number of subdivisions in wavenumber, k, to converge.
	 So don't be shy about asking CAMB or your favorite linear P(k) generator to give you a 
	 power spectrum densely sampled in wavenumber. 
	
	 Note that the change to the Smith et al parameters has been made via prefactor multiplication
	 in the subroutine halofit. This is intended to make it trivially easy to remove the modifications,
	 either because you want to test our version of halofit or just like our version of the code.
	 
	 To recover the results of standard gravity entirely, you must either take out all the modified stuff, or -- 
	 and this is easier -- simply make rc >> h/H0. Note that simply making rc very large will NOT
	 automatically return the Halofitting parameters to their standard gravity values.
 */

#include <math.h>
#include <fstream>
#include <iostream>
#include <stdlib.h>
#include <vector>

//#include <malloc.h>
#include <stdio.h>
#include <assert.h>

//#include "smith2.h"


// Mathematical constants: pi, pi^2, 2*pi, ln(2), and one arcminute in radians (pi/10800).
// NOTE: these are plain macros, so identifiers named `pi`, `twopi`, ... cannot be used anywhere below.
#define pi     3.14159265358979323846
#define pi_sqr 9.86960440108935861883
#define twopi  6.28318530717958647693
#define ln2    0.69314718
#define arcmin 2.90888208665721580e-4

// Small tolerances. Neither is referenced by the code visible in this file.
#define epsilon  1e-5
#define epsilon2 1e-18

// Square of `a`, evaluated once through a scratch variable `darg`.
// The macro expects a variable named `darg` to be in scope at the point of use; none is declared
// in the code visible here, and the macro itself is not used here.
#define DSQR(a) ( (darg=(a))==0.0 ? 0.0 : darg*darg )



using namespace std;

// ADAPTIVE STEP-SIZE CONTROL (all used by rkqc unless noted)
const double PGROW= -0.20;            // exponent applied to the scaled error when enlarging the next step
const double PSHRNK = -0.25;          // exponent applied to the scaled error when shrinking a rejected step
const double FCOR = 0.0666666666667;  // weight (1/15) of the error estimate added back to the accepted solution
const double SAFETY = 0.9;            // safety factor multiplying every new step-size estimate
const double ERRCON = 6.0e-4;         // scaled error at or below which the next step is simply 4x the current one
const double hmax = 0.2;              // upper limit on the suggested next step (in e-folds N)
const double TINY = 1.0e-30;          // not referenced by the code visible in this file

// MODIFIED GRAVITY MODEL PARAMETERS
const double rc = 500.; // in Mpc h^-1
const double alpha = 0.0; // 0.5 for DGP; 0.0 for cascading gravity / degravitation

// BACKGROUND COSMOLOGY
const double H0 =1/2998.; // Hubble/c in Mpc^-1
const double h = 0.7;     // multiplies H0 wherever the Hubble rate is formed (see Hubble())

const double OmR = 0.000; // radiation density parameter entering Hubble()
const double OmM = 0.3;   // matter density parameter
//const double OmL = 1- OmM - OmR

// Number of wavenumber samples: rows read from each input file and length of every per-mode array.
const int Nmodes = 4345;

// Forward declarations. Definitions follow main().

// Hubble rate at scale factor a, built from H0, h, OmM and OmR.
double Hubble(double a);
// Returns -1/(3*beta(a)).
double g(double a);
// Returns (k*rc/a)^(2*(alpha-1)).
double f(double a, double k);
// Returns 1 + 2*(Hubble(a)*rc)^(2*(1-alpha)) * (1 + growth(a)/3).
double beta(double a);
// Returns -1.5 times the matter fraction OmM*a^-3 / (OmM*a^-3 + 1 - OmM).
double growth(double a);
// Base-10 logarithm.
double dlog(double x);
// Right-hand side of the second-order perturbation equation (derivative of capD with respect to N).
double eq1(double a, double k, double capD, double del);
// Sums tabulated values F over the logarithmic spacing of k.
double integrate(vector<double> F, vector<double> k);
// Applies the halofit-style nonlinear correction to the tabulated spectrum P; result goes to PNL.
void NL(vector<double> P, vector<double> k, vector<double> & PNL);
// One fixed-size fourth-order Runge-Kutta step of size h for a single mode, starting at N.
void rk4(double delin, double Din, double & delout, double & Dout, double k, double N, double h);
// Nonlinear fitting formula for one wavenumber; writes the result to pnl.
void halofit(double rk, double rn, double rncur, double rknl, double plin,double om_m, double om_v, double & pnl);
// One error-controlled step for all modes at once; advances N by the step actually taken.
void rkqc(vector<double> & y, vector<double> & yscal, vector<double> & dy, double & N, double htry, double & hdid, double & hnext, double eps);



int main()
{
	double anow, aprev, znow = 49.0,Hnow,dN,N,daft,dNfirst;
	double htry, hdid, hnext,eps=0.0001,factor;
	int count,j;
	ifstream p_in,p_in2;
	ofstream p_out,grout;
	p_in.open("z49p1.txt");
	p_in2.open("z49pk.txt");
	p_out.open("Pknow.txt");
	grout.open("growthrate.txt");
	
	
	anow = 1./(1.+znow);
	aprev = 1./(1.+49.1);
	Hnow=Hubble(anow);
	htry = 0.0001; // fiducial step size in units of efolding time N = H dt
	N = log(anow);
	
  dNfirst = log(anow) - log(aprev);
  vector<double> Pk(Nmodes), delta(Nmodes), k(Nmodes),Pk2(Nmodes),deldot(Nmodes),PNL(Nmodes),DL(Nmodes);  
  vector<double> y(3*Nmodes),dy(3*Nmodes),yscal(2*Nmodes);
  

	
  for(int i=0;i<Nmodes;i++)
	{
		p_in >> k[i] >> Pk[i];
		p_in2 >> daft >> Pk2[i];	
	}
	for(int i=0;i<Nmodes;i++)
	{
		delta[i] = sqrt(k[i]*k[i]*k[i]*Pk2[i]/(2.*3.1415*3.1415));
		deldot[i] = (delta[i] - sqrt(k[i]*k[i]*k[i]*Pk[i]/(2.*3.1415*3.1415)))/dNfirst;
	}
	
			count =0;

	for(j=0;j<Nmodes;j++)
		{
			yscal[j]=delta[j];
			yscal[j+Nmodes]=deldot[j];
			y[j+2*Nmodes] = k[j];
		}
	for(int i=0;i<Nmodes;i++)
		{
			y[i] = delta[i];
			y[i+Nmodes] = deldot[i];
		}
	
	for(;;)
	{
	count += 1;
	 N = log(anow);
		rkqc(y,yscal,dy,N,htry,hdid,hnext,eps);
		htry = hnext;
		anow = exp(N);
		Hnow = Hubble(anow);
		
		if(count%10==0) grout << anow << " " <<  y[10+Nmodes]/y[10] << " " << y[1000+Nmodes]/y[1000] << " " << y[2200+Nmodes]/y[2200] << " "<< y[4000+Nmodes]/y[4000] << " " << H0*h*sqrt(OmM*pow(anow,-3.0))/Hubble(anow) << endl;
		
	 if((N+hnext)>0) 
		{
		dN = -1*N;
		for(int i=0;i<Nmodes;i++)
			{
			delta[i] = y[i];
			deldot[i] = y[i+Nmodes];
			rk4(delta[i],deldot[i],delta[i],deldot[i],k[i],N,dN);
			}
	   N = N + dN;	
		break;
		}
	}
	
		anow = exp(N);
		Hnow = Hubble(anow);
		


	   for(int i=0;i<Nmodes;i++)
	{
		Pk[i] =twopi*pi*pow(delta[i],2.0)/(k[i]*k[i]*k[i]);
		DL[i] =pow(delta[i],2.0); 
	}
	
	NL(DL,k,PNL);
	
	for(int i=0;i<Nmodes;i++)
	{
		PNL[i] = twopi*pi*PNL[i]/(k[i]*k[i]*k[i]);
		p_out << k[i] << " " << Pk[i] << " " << PNL[i] << endl;
	}
	 
  
}

double Hubble(double a)
{
double H;

H = H0*h*sqrt(OmM*pow(a,-3.0)+OmR*pow(a,-4.0)+(1-OmM-OmR));

return H;
}

// Returns -1/(3*beta(a)). eq1() uses it as the factor (1 - g) on its source term.
double g(double a)
{
	return -1./(3.*beta(a));
}

// Scale-dependent factor (k*rc/a)^(2*(alpha-1)); eq1() divides its source term by (1 + f).
double f(double a, double k)
{
	const double exponent = 2.*(alpha-1.);
	return pow((k*rc/a),exponent);
}

// Returns 1 + 2*(Hubble(a)*rc)^(2*(1-alpha)) * (1 + growth(a)/3).
// NOTE(review): the file-level comments give rc in Mpc/h while Hubble() carries an explicit
// factor of h -- confirm the intended unit convention of the product Hubble(a)*rc.
double beta(double a)
{
	const double crossover  = pow(Hubble(a)*rc,2.*(1-alpha));
	const double background = (1. + growth(a)/3.);

	return 1. + 2.*crossover*background;
}

// Returns -1.5 * OmM*a^-3 / (OmM*a^-3 + 1 - OmM).
// Used as the extra friction term in eq1() and inside beta().
// NOTE(review): unlike Hubble(), this expression contains no OmR term; the two agree only
// while OmR is zero.
double growth(double a)
{
	const double matter = OmM*pow(a,-3.);
	return -1.5*(matter/(matter+1-OmM));
}

// Right-hand side of the second-order equation integrated by rk4():
// returns the derivative of capD with respect to N, given the scale factor a, wavenumber k,
// capD (the N-derivative of del) and the amplitude del.
//
// The result is a friction term -(2 + growth(a))*capD plus a source term proportional to del.
// While del is below the threshold 3000*a^3 the source carries the extra factor (1 - g(a));
// at or above the threshold that factor is dropped.
double eq1(double a, double k, double capD, double del)
{
	const double dcrit = 3000*pow(a,3);

	const double coupling = (del<dcrit) ? 1.5*(1.-g(a)) : 1.5;

	const double friction = -1*(2.+growth(a))*capD;
	const double source   = coupling*OmM*pow(a,-3.)*pow((h*H0)/Hubble(a),2.0)*del/(1+f(a,k));

	return friction + source;
}
// One classical fourth-order Runge-Kutta step of size h for a single mode.
//
//   delin, Din    amplitude and its N-derivative at the start of the step
//   delout, Dout  receive the values at N + h (may alias the caller's input variables,
//                 because the inputs are taken by value)
//   k             wavenumber of the mode
//   N             starting time, N = ln(a)
//   h             step size in N (this parameter hides the file-level constant h in this scope)
void rk4(double delin, double  Din, double & delout, double & Dout, double k, double N, double h)
{
	const double a_start = exp(N);
	const double a_mid   = exp(N + h*0.5);
	const double a_end   = exp(N + h);

	// Stage 1: slopes at the start of the step.
	const double k1 = h*Din;
	const double l1 = h*eq1(a_start,k,Din,delin);

	// Stage 2: slopes at the midpoint, using the stage-1 slopes.
	const double D2   = Din + 0.5*l1;
	const double del2 = delin + 0.5*k1;
	const double k2 = h*D2;
	const double l2 = h*eq1(a_mid,k,D2,del2);

	// Stage 3: slopes at the midpoint again, using the stage-2 slopes.
	const double D3   = Din + 0.5*l2;
	const double del3 = delin + 0.5*k2;
	const double k3 = h*D3;
	const double l3 = h*eq1(a_mid,k,D3,del3);

	// Stage 4: slopes at the end of the step.
	const double D4   = Din + l3;
	const double del4 = delin + k3;
	const double k4 = h*D4;
	const double l4 = h*eq1(a_end,k,D4,del4);

	// Weighted combination 1/6, 1/3, 1/3, 1/6.
	delout = delin+ k1/6.0 + k2/3.0 + k3/3.0 + k4/6.0;
	Dout   = Din + l1/6.0 + l2/3.0 + l3/3.0 + l4/6.0;
}
// Applies the nonlinear (halofit-style) correction to a tabulated dimensionless spectrum.
//
//   P    input values, one per wavenumber (main() passes delta^2)
//   k    wavenumbers, same length as P
//   PNL  receives the corrected values, one per wavenumber
//
// Steps:
//   1. Find the smoothing radius rmid at which the Gaussian-filtered integral of P over ln k
//      equals 1 (to within 0.001) by bisection in log10(r). If 20 bisections are not enough,
//      the bracket is extended by 5 decades on the side the search was drifting toward and
//      the bisection is restarted.
//   2. From two further filtered integrals at that radius compute the effective index rneff
//      and curvature rncur.
//   3. Evaluate halofit() for every wavenumber with nonlinear wavenumber 1/rmid.
//
// If the bracket would have to be pushed to 1/r > 1e6 no nonlinear scale exists in the data;
// a warning is printed and the linear values are returned unchanged in PNL.
void NL(vector<double> P, vector<double> k, vector<double> & PNL)
{
	const double kNLstern = 1.e6;       /* h/Mpc */
	const double logstep  = 5.0;
	const int    itermax  = 20;
	const double tol      = 0.001;

	vector<double> F(Nmodes);

	/* find non-linear scale with iterative bisection */
	double logr1 = -2.0;
	double logr2 =  3.5;
	double rmid  = 0.0;
	bool   golinear = false;

	for(;;)
	{
		const double logr1start = logr1;
		const double logr2start = logr2;
		double logrmid = 0.0;
		double diff    = 0.0;
		int    iter    = 0;

		do
		{
			logrmid = (logr2+logr1)/2.0;
			rmid    = pow(10,logrmid);
			for(int i=0;i<Nmodes;i++)
			{
				const double ksqr = pow(k[i]*rmid,2.0);
				F[i] = P[i]*exp(-ksqr);
			}
			diff = integrate(F,k) - 1.0;

			// Filtered variance too large -> radius too small -> raise the lower bound,
			// and vice versa.
			if(diff>tol)
				logr1 = dlog(rmid);
			if(diff<-tol)
				logr2 = dlog(rmid);

		} while (fabs(diff)>=tol && ++iter<itermax);

		if (iter<itermax) break;        // converged

		// Not converged: extend the bracket on the side the midpoint ended up on.
		const double logrmidtmp = (logr2start+logr1start)/2.0;
		if (logrmid<logrmidtmp) {
			logr1 = logr1start-logstep;
			logr2 = logrmid;
		} else {
			logr1 = logrmid;
			logr2 = logr2start+logstep;
		}

		if (1/pow(10, logr2)>kNLstern) {
			cout << " Trouble in dodge! " << endl;
			golinear = true;
			break;
		}
	}

	if (golinear)
	{
		// No usable nonlinear scale: rmid is meaningless, so do not feed it to halofit.
		for(int i=0;i<Nmodes;i++)
			PNL[i] = P[i];
		return;
	}

	// find rneff and rncur
	for(int i=0;i<Nmodes;i++)
	{
		const double ksqr = pow(k[i]*rmid,2.0);
		F[i] = 2.0*ksqr*P[i]*exp(-ksqr);
	}
	const double d1    = -integrate(F,k);
	const double rneff = -3 - d1;

	for(int i=0;i<Nmodes;i++)
	{
		const double ksqr = pow(k[i]*rmid,2.0);
		F[i] = 4.0*ksqr*(1.0-ksqr)*P[i]*exp(-ksqr);
	}
	const double rncur = d1*d1 + integrate(F,k);

	// Density parameters come from the file-level cosmology rather than from literals, so
	// that the correction stays consistent with the background used in the integration.
	const double om_m = OmM;
	const double om_v = 1.-OmM-OmR;

	for(int i=0;i<Nmodes;i++)
	{
		double pnonl = 0.0;
		halofit(k[i],rneff,rncur,1.0/rmid,P[i],om_m,om_v,pnonl);
		PNL[i] = pnonl;
	}
}

// Base-10 logarithm of x, formed as a ratio of natural logarithms.
double dlog(double x)
{
   const double ln_x  = log(x);
   const double ln_10 = log(10.0);
   return ln_x/ln_10;
}

// Integrates tabulated values F over ln(k).
//
//   F  integrand samples
//   k  wavenumbers at which F is sampled (same length as F, assumed positive and increasing)
//
// Each interval [k[i-1], k[i]] contributes with the logarithmic width
// dlk = (k[i]-k[i-1]) / mean(k[i], k[i-1]). The first interval is i = 1: there is no sample
// to the left of index 0, so the sum must not start there. Tables with fewer than two
// samples integrate to 0. If the two arrays differ in length, only the common prefix is used.
//
// NOTE(review): the per-interval weight is F[i] + 0.5*(F[i]-F[i-1]) = 1.5*F[i] - 0.5*F[i-1].
// A trapezoid rule would be 0.5*(F[i]+F[i-1]). The weights are kept exactly as they were so
// results stay comparable with earlier runs -- confirm which form is intended.
double integrate(std::vector<double> F, std::vector<double> k)
{
	typedef std::vector<double>::size_type index_t;
	const index_t n = (F.size() < k.size()) ? F.size() : k.size();

	double sum = 0;
	for(index_t i=1; i<n; i++)
	{
		const double dlk = (k[i] - k[i-1])/(0.5*k[i]+0.5*k[i-1]);
		sum += F[i]*dlk + 0.5*(F[i]-F[i-1])*dlk;
	}

	return sum;
}


// Nonlinear fitting formula for a single wavenumber.
//
//   rk     wavenumber
//   rn     effective spectral index at the nonlinear scale
//   rncur  spectral curvature at the nonlinear scale
//   rknl   nonlinear wavenumber (1 / smoothing radius found by NL())
//   plin   linear dimensionless power at rk
//   om_m   matter density parameter
//   om_v   vacuum density parameter
//   pnl    receives the nonlinear dimensionless power: quasi-linear term + halo term
//
// The polynomial coefficients below are functions of rn and rncur. The leading numeric
// prefactors (0.84 in the exponent of a, 1.1 on b, 1.05 on c, 0.8 on xmu, xnu and alpha2,
// 2.0 on the beta exponent, and 1.02 on y) are presumably the "prefactor multiplication"
// mentioned in the file header; replacing them by 1 should give back the unmodified fit --
// confirm against Smith et al. (2003).
void halofit(double rk, double rn, double rncur, double rknl, double plin, double om_m, double om_v, double & pnl)
{
   const double nsqr = rn*rn;

   const double gam = 0.86485 + 0.2989*rn + 0.1631*rncur;
   double a = 0.84*(1.4861 + 1.83693*rn + 1.67618*nsqr + 0.7940*rn*nsqr
     + 0.1670756*nsqr*nsqr - 0.620695*rncur);
   a = pow(10,a);
   const double b = 1.1*pow(10,(0.9463+0.9466*rn+0.3084*nsqr-0.940*rncur));
   const double c = 1.05*pow(10,(-0.2807+0.6669*rn+0.3214*nsqr-0.0793*rncur));
   const double xmu = 0.8*pow(10,(-3.54419+0.19086*rn));
   const double xnu = 0.8*pow(10,(0.95897+1.2857*rn));
   const double alpha2 = 0.8*(1.38848+0.3701*rn-0.1452*nsqr);
   // Exponent of (1 + plin) in the quasi-linear term. It has its own name so that it neither
   // hides the file-level function beta() nor can be read before it has been assigned.
   const double beta_n = 2.0*(0.8291+0.9854*rn+0.3400*nsqr);

   // Density-parameter dependence: interpolate between the two sets of power laws
   // according to om_v/(1-om_m); all three factors are 1 when om_m is within 0.01 of 1.
   double f1, f2, f3;
   if(fabs(1-om_m)>0.01) {
      const double f1a = pow(om_m,(-0.0732));
      const double f2a = pow(om_m,(-0.1423));
      const double f3a = pow(om_m,(0.0725));
      const double f1b = pow(om_m,(-0.0307));
      const double f2b = pow(om_m,(-0.0585));
      const double f3b = pow(om_m,(0.0743));

      const double frac = om_v/(1.-om_m);
      f1 = frac*f1b + (1-frac)*f1a;
      f2 = frac*f2b + (1-frac)*f2a;
      f3 = frac*f3b + (1-frac)*f3a;
   } else {      /* EdS Universe */
      f1 = f2 = f3 = 1.0;
   }

   const double y = 1.02*rk/rknl;
   const double ysqr = y*y;

   // Halo term.
   double ph = a*pow(y,f1*3)/(1+b*pow(y,f2)+pow(f3*c*y,3-gam));
   ph = ph/(1+xmu/y+xnu/ysqr);

   // Quasi-linear term.
   const double pq = plin*pow(1+plin,beta_n)/(1+plin*alpha2)*exp(-y/4.0-ysqr/8.0);

   pnl = pq + ph;

//   assert(finite(*pnl));
}
 
// One error-controlled Runge-Kutta step for all modes at once (step doubling).
//
//   y      state vector of length 3*Nmodes: [amplitude | its N-derivative | wavenumber].
//          The first two segments are advanced in place; the third is only read.
//   yscal  error scales, length 2*Nmodes: [amplitude scale | derivative scale]
//   dy     on return, the change of the first two segments over the accepted step
//          (before the FCOR correction)
//   N      current time N = ln(a); advanced by the step actually taken
//   htry   step size to attempt first
//   hdid   receives the step size that was accepted
//   hnext  receives the suggested size for the next step (never larger than hmax)
//   eps    tolerance on the largest scaled error over all modes
//
// Each trial advances every mode twice with h/2 and once with h; the difference between the
// two answers is the error estimate. The step is shrunk and retried until the largest scaled
// error is at most eps.
void rkqc(vector<double> & y, vector<double> & yscal, vector<double> & dy, double & N, double htry, double & hdid, double & hnext, double eps)
{
  int i;
  vector<double> delsav(Nmodes), vsav(Nmodes), k(Nmodes);   // state at the start of the step
  vector<double> del(Nmodes), v(Nmodes);                    // result of the two half steps
  vector<double> delerr(Nmodes), verr(Nmodes);              // two-half-steps minus one-full-step

  for(i=0;i<Nmodes;i++)
    {
      delsav[i] = y[i];
      vsav[i]   = y[i+Nmodes];
      k[i]      = y[i+2*Nmodes];
    }

  double h = htry;
  for(;;)
    {
      const double hh = 0.5*h;
      double errmax = 0.0;

      for(i=0;i<Nmodes;i++)
        {
          double delmid, vmid, delbig, vbig;

          // Two half steps. The second one starts where the first one ended, i.e. at N + hh;
          // the coefficients of the equation depend on time, so the start time matters.
          rk4(delsav[i],vsav[i],delmid,vmid,k[i],N,hh);
          rk4(delmid,vmid,del[i],v[i],k[i],N+hh,hh);

          // One full step from the same starting point.
          rk4(delsav[i],vsav[i],delbig,vbig,k[i],N,h);

          delerr[i] = del[i]-delbig;
          double temp = fabs(delerr[i]/yscal[i]);
          if(errmax<temp) errmax = temp;

          verr[i] = v[i]-vbig;
          temp = fabs(verr[i]/yscal[i+Nmodes]);
          if(errmax<temp) errmax = temp;

          dy[i] = del[i]-delsav[i];
          dy[i+Nmodes] = v[i]-vsav[i];
        }

      errmax /= eps;
      if(errmax <= 1.0)
        {
          // Accepted. Grow the next step according to the error, or by a flat factor 4
          // when the error is so small that the power law would overshoot.
          hdid = h;
          if(errmax > ERRCON)
            hnext = SAFETY*h*exp(PGROW*log(errmax));
          else
            hnext = 4.0 * h;
          if(hnext > hmax) hnext = hmax;
          break;
        }

      // Rejected: shrink and retry from the saved starting state.
      h = SAFETY*h*exp(PSHRNK*log(errmax));
    }

  // Add the leading-order error estimate back to the two-half-step answer.
  for(i=0;i<Nmodes;i++)
    {
      y[i] = del[i] + delerr[i]*FCOR;
      y[i+Nmodes] = v[i] + verr[i]*FCOR;
    }
  N += hdid;
}
