package odeToJava;

import odeToJava.modules.*;
import java.io.*;

/*
   Class applies any explicit Runge-Kutta scheme with a constant stepsize
   to any ODE (no interpolation).
*/
public class Erk
{
    // static methods (used as the interface to solver)

    /**
     * Integrates {@code function} over {@code tspan} with a constant-stepsize
     * explicit Runge-Kutta scheme, writing roughly {@value #DEFAULT_N_POINTS}
     * points of the solution to {@code fileName}.
     *
     * @param function the right-hand side of the ODE
     * @param tspan    the interval of integration
     * @param x0       the initial value
     * @param h        the (constant) stepsize; must be greater than zero
     * @param butcher  the Butcher tableau describing the scheme
     * @param fileName name of the file the solution is written to
     * @param stats    one of "Stats_On", "Stats_Intermediate" or "Stats_Off"
     */
    public static void erk(ODE function, Span tspan, double[] x0, double h, Btableau butcher, String fileName, String stats)
    {
        erk(function, tspan, x0, h, butcher, fileName, stats, DEFAULT_N_POINTS);
    }

    /**
     * Same as the seven-argument overload, but writes roughly {@code nPoints}
     * points of the solution to {@code fileName}.
     *
     * @param nPoints approximate number of points written to file; must be positive
     * @throws IllegalArgumentException if {@code nPoints} is not positive
     */
    public static void erk(ODE function, Span tspan, double[] x0, double h, Btableau butcher, String fileName, String stats, int nPoints)
    {
        Erk erk = new Erk(function, tspan, x0, h, butcher, fileName, stats);   // initialize the ERK object
        erk.setNPoints(nPoints);   // amount of points written to file set to value user specifies
        erk.routine();   // run the routine
    }

    // helper methods (for static methods)

    /**
     * Sets the approximate number of points written to the solution file.
     *
     * @param nPoints number of points; must be positive (zero would cause a
     *                division by zero when choosing which steps to write)
     * @throws IllegalArgumentException if {@code nPoints} is not positive
     */
    public void setNPoints(int nPoints)
    {
        if (nPoints <= 0)
        {
            throw new IllegalArgumentException("nPoints must be positive, got " + nPoints);
        }
        this.nPoints = nPoints;
    }

    // constructors

    /**
     * Sets up the class to do a Runge-Kutta scheme given the ODE, an interval
     * of (temporal) integration, an initial value, a (constant) stepsize, a
     * Butcher tableau, and a few strings for special features.
     *
     * Invalid input (out-of-order span, non-positive stepsize, stepsize larger
     * than the span, unknown stats string) prints a message and terminates the
     * JVM with a nonzero exit status.
     */
    public Erk(ODE function, Span tspan, double[] x0, double h, Btableau butcher, String fileName, String stats)
    {
        // span check: if the span is out of order halt immediately (with message)
        if (!tspan.get_property())
        {
            System.out.println("Improper span: times are out of order");
            System.exit(1);   // nonzero status: this is an error, not a success
        }

        // general initializations

        this.f = function;
        this.t0 = tspan.get_t0();
        this.tf = tspan.get_tf();
        this.x0 = x0.clone();   // defensive copy: later caller edits can't change the IVP
        this.h = h;
        this.s = butcher.getbl();   // number of stages equals the length of b

        // copy a, b and c of the Butcher tableau into arrays sized by the tableau
        this.a = new double[butcher.getal()][butcher.getal()];
        this.b = new double[butcher.getbl()];
        this.c = new double[butcher.getcl()];
        StdMet.matrixcpy(this.a, butcher.get_a());
        StdMet.arraycpy(this.b, butcher.get_b());
        StdMet.arraycpy(this.c, butcher.get_c());

        // whether the scheme is first same as last (reuse last stage as next first stage)
        this.FSALenabled = butcher.get_FSALenabled();

        // stepsize checks; !(h > 0) also rejects NaN, which h <= 0 would let through
        if (!(h > 0))
        {
            System.out.println("Stepsize must be greater than zero");
            System.exit(1);
        }

        if ((tf - t0) < h)
        {
            System.out.println("Stepsize is larger than tspan");
            System.exit(1);
        }

        this.n = x0.length;   // dimension of ODE

        // Number of steps. Rounding in (tf - t0) and in the division can push an
        // exact multiple just below an integer (e.g. 0.3/0.1 = 2.9999999999999996),
        // which a plain floor would turn into one step too few; forgive that slack.
        double quotient = (tf - t0) / h;
        double nearest = Math.rint(quotient);
        double roundingSlack = ROUNDING_SLACK_ULPS * Math.ulp(1.0) * (Math.abs(t0) + Math.abs(tf)) / h;
        boolean integerMultiple = Math.abs(quotient - nearest) <= roundingSlack;
        this.steps = integerMultiple ? (long) nearest : (long) Math.floor(quotient);

        // initializations for miscellaneous special features
        configureStats(stats);
        this.fileName = fileName;   // the file writing feature

        // output warnings for cases that are not worth stopping the program for

        if (!integerMultiple)   // solver cannot step exactly onto tf, so notify user
        {
            System.out.println();
            System.out.println("integration interval is not integer multiple of stepsize:");
            System.out.println("problem will only be integrated to t = " + (t0 + steps * h));
            System.out.println();
        }

        if (tspan.get_timesLength() > 0)   // a span suggesting interpolation was given,
        {   // but this solver does not interpolate
            System.out.println();
            System.out.println("note that this solver does not do interpolation . . .");
            System.out.println();
        }
    }

    /**
     * Sets the statistics switches from the user's stats string; prints a
     * message and exits with a nonzero status for unrecognized (or null) input.
     */
    private void configureStats(String stats)
    {
        if ("Stats_On".equals(stats))
        {
            this.stats_on = true;
            this.stats_intermediate = false;
        }
        else if ("Stats_Intermediate".equals(stats))
        {
            this.stats_on = true;
            this.stats_intermediate = true;
        }
        else if ("Stats_Off".equals(stats))
        {
            this.stats_on = false;
            this.stats_intermediate = false;
        }
        else
        {
            System.out.println("String parameter must be either: 1) \"Stats_On\" 2) \"Stats_Intermediate\" or 3) \"Stats_Off\"");
            System.exit(1);
        }
    }

    // methods

    /**
     * Computes the solution to the ODE using the parameters stored by the
     * constructor. Takes exactly {@code steps} steps of size {@code h}, writes
     * about {@code nPoints} points (plus the final point) to the file, and
     * prints the final time and solution. Aborts (after closing the file) if
     * the RMS norm of the solution becomes NaN or infinite.
     */
    public void routine()
    {
        // beginning message

        System.out.println();
        System.out.println("Begin Explicit Runge-Kutta Routine . . .");
        System.out.println();   // leave a space at start (for screen output)

        // initializations

        told = t0;
        xold = new double[n];
        xnew = new double[n];
        K = new double[s][n];   // s rows of size n, one per stage

        double[] sigma = new double[n];    // running sum a[i][0]*K[0] + ... + a[i][i-1]*K[i-1]
        double[] sigma2 = new double[n];   // running sum of h*b[i]*K[i] (increment for xnew)
        double[] as1 = new double[n];      // stage argument xold + h*sigma
        double[] stam1 = new double[n];    // scratch for scalar*array products

        StdMet.arraycpy(xold, x0);   // start from the initial value
        firstStep = true;

        // write every step if there are few of them, otherwise every stride-th step
        long stride = (steps <= nPoints) ? 1 : steps / nPoints;
        boolean lastPointWritten = false;   // did the most recent step go to the file?

        ODEFileWriter writer = new ODEFileWriter();
        writer.openFile(fileName);
        try
        {
            // count-based loop: exactly `steps` steps, no float-drift in the stop test
            for (count = 0; count < steps; count++)
            {
                double tnew = t0 + (count + 1) * h;   // computed from t0 to avoid accumulating rounding

                // compute each row i of the K matrix from the tableau and earlier rows
                for (int i = 0; i < s; i++)
                {
                    if (i == 0 && !firstStep && FSALenabled)
                    {
                        // FSAL: last stage of previous step is this step's first stage
                        StdMet.arraycpy(K[0], K[s - 1]);
                        continue;
                    }

                    for (int j = 0; j < i; j++)
                    {
                        StdMet.stam(stam1, a[i][j], K[j]);      // a[i][j]*K[j]
                        StdMet.arraysum(sigma, sigma, stam1);   // sigma += a[i][j]*K[j]
                    }

                    StdMet.stam(stam1, h, sigma);        // h*sigma
                    StdMet.arraysum(as1, xold, stam1);   // xold + h*sigma
                    StdMet.arraycpy(K[i], f.f(told + h * c[i], as1));
                    StdMet.zero_out(sigma);   // reset for the next row
                }

                // weighted average of the K rows using b
                for (int i = 0; i < s; i++)
                {
                    StdMet.stam(stam1, h * b[i], K[i]);       // h*b[i]*K[i]
                    StdMet.arraysum(sigma2, sigma2, stam1);   // sigma2 += h*b[i]*K[i]
                }

                StdMet.arraysum(xnew, xold, sigma2);   // xnew = xold + sigma2
                StdMet.zero_out(sigma2);

                double norm = StdMet.rmsNorm(xnew);
                if (Double.isNaN(norm) || Double.isInfinite(norm))   // solver has gone unstable
                {
                    System.out.println("unstable . . . aborting");
                    return;   // the finally block closes the writer
                }

                lastPointWritten = (count % stride == 0);
                if (lastPointWritten)
                {
                    writer.writeToFile(tnew, xnew);
                }

                if (stats_on)
                {
                    printStepStats(tnew);
                }

                StdMet.arraycpy(xold, xnew);   // prepare for the next step
                told = tnew;
                firstStep = false;
            }

            System.out.println("done");
            System.out.println("final t = " + told);   // output final time
            System.out.println("final solution =");   // and solution
            StdMet.arrayprt(xold);

            // make sure the final point is in the file, without writing it twice
            if (!lastPointWritten)
            {
                writer.writeToFile(told, xold);
            }
        }
        finally
        {
            writer.closeFile();
        }
    }

    /**
     * Prints per-step statistics for the step told -> tnew; the count and the
     * solution are omitted in intermediate statistics mode.
     */
    private void printStepStats(double tnew)
    {
        if (!stats_intermediate)
        {
            System.out.println("count: " + count);
        }

        System.out.println("stepping: " + told + " -> " + tnew);

        if (!stats_intermediate)
        {
            System.out.println("solution = ");
            StdMet.arrayprt(xnew);
        }

        System.out.println();
    }

    // constants

    private static final int DEFAULT_N_POINTS = 1000;   // default number of points written to file
    private static final double ROUNDING_SLACK_ULPS = 16;   // tolerance (in ulps) for "integer multiple" test

    // instance variables

    private ODE f;   // the function of the differential equation
    private double t0;   // the starting time
    private double tf;   // the stopping time
    private double[] x0;   // the initial value
    private double h;   // the stepsize of the integration
    private int s;   // the number of stages of this Runge-Kutta scheme
    private boolean FSALenabled;   // whether first same as last functionality of the
        // scheme (if this scheme has the property to begin with) is enabled

    private double[][] a;   // the matrix a of the given Butcher tableau
    private double[] b;   // the array b of the given Butcher tableau
    private double[] c;   // the array c of the given Butcher tableau

    private int n;   // dimension of ODE

    private String fileName;   // name of file solution is written to
    private boolean stats_on;   // whether to report status at each step
    private boolean stats_intermediate;   // whether to report just a few statistics

    private long steps;   // the actual number of steps of integration needed
    private long count;   // counter to count the steps of the integration
    private double told;   // stores the current t value
    private boolean firstStep;   // whether the loop of routine is on its first step

    private double[] xold;   // stores the current x value (xold)
    private double[] xnew;   // stores the next x value (xnew)
    private double[][] K;   // matrix of K values (s rows of size n)

    private int nPoints = DEFAULT_N_POINTS;   // number of points to write to file
}
