#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include "mpi.h"
#include "sparse_globals.h"
#include "itpack.h"
#include "netsolve_itpack.h"
#include "itpack_auxs.h"
#include "itpack_impl.h"
#include "matrix_auxs.h"
#include "parallel_auxs.h"
#include "driver/iterative.h"

/*
 * Forward declaration of the ITPACK dispatch routine, which is defined
 * further down in this file after the driver that calls it.
 */
int itpack_iterative_solve
(MPI_Comm comm,
 NetsolveItmeth itmeth,
 NetsolvePrecond precond,
 double *val,
 double *rhs,
 int first,
 int last,
 int nnzero,
 int *idx,
 int *ptr,
 double *sol,
 double rtol,
 int maxit,
 int *conv,
 int *its);

/*
 * Solve a sequential CRS system with ITPACK.
 *
 * comm        communicator; single_proc must report that it holds exactly
 *             one process, otherwise an error is reported
 * values, indices, pointers
 *             CRS matrix of local_size rows; converted in place with
 *             crs_mat_tobase1 before the solve and converted back with
 *             crs_mat_tobase0 afterwards -- on the error path as well, so
 *             the caller never gets back a half-converted matrix
 * first       must be 0 (the matrix may not be distributed)
 * local_size  number of rows
 * rhs_vector  right-hand side
 * sol_vector  solution vector (presumably also the initial guess on
 *             input -- confirm against the ITPACK routines)
 * info        iterative_info_block, cast to itpack_info_block; its itmeth,
 *             precond, rtol and maxit fields configure the solve, and the
 *             convergence flag and iteration count are stored back through
 *             iterative_set_return_params
 *
 * Returns 0 on success, otherwise the error code propagated through
 * ERR_RETURN / ERR_REPORT.
 */
int itpack_iterative_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 iterative_info_block info)
{
  itpack_info_block itpack_info = (itpack_info_block) info;
  int conv,its,nnz,flg,ierr,ierr_restore;

  ierr = single_proc(comm,&flg); ERR_RETURN(ierr);
  if (flg!=1) ERR_REPORT("Itpack should not be run in parallel");
  if (first!=0) ERR_REPORTi("First var should be zero; is",first);

  /* ITPACK is Fortran code and expects 1-based indices */
  ierr = crs_mat_tobase1(values,indices,pointers,local_size); ERR_RETURN(ierr);
  first++;

  /* From here on the matrix is 1-based: do not return early, so that the
     conversion back to base 0 below always happens. */
  ierr = crs_nnzeros(pointers,local_size,1,&nnz);
  if (!ierr)
    ierr = itpack_iterative_solve
      (comm, itpack_info->itmeth,itpack_info->precond, values,rhs_vector,
       first,first+local_size,nnz,
       indices,pointers,sol_vector,
       itpack_info->rtol,itpack_info->maxit,&conv,&its);
  if (!ierr)
    ierr = iterative_set_return_params(info,conv,its);

  ierr_restore = crs_mat_tobase0(values,indices,pointers,local_size);

  /* the first failure takes precedence over a failure to restore */
  ERR_RETURN(ierr);
  ERR_RETURN(ierr_restore);

  return 0;
}

/* ITPACK routines that itpack_iterative_solve can dispatch to */
enum itpack_solver_choice {
  ITP_SOLVER_JCG, ITP_SOLVER_SSORCG, ITP_SOLVER_RSCG,
  ITP_SOLVER_JSI, ITP_SOLVER_SSORSI, ITP_SOLVER_RSSI
};

/*
 * Map an (iterative method, preconditioner) pair to the ITPACK routine
 * that implements it, storing the choice in *solver.  Unsupported
 * combinations are reported through ERR_REPORTi.  This runs before any
 * work space is allocated, so a rejected combination cannot leak memory.
 */
static int itpack_select_solver
(NetsolveItmeth itmeth, NetsolvePrecond precond, int *solver)
{
  if (ITMETH_IS(itmeth,NetsolveCG)) {
    if (PRECOND_IS(precond,NetsolveJacobi)) {
      *solver = ITP_SOLVER_JCG;
    } else if (PRECOND_IS(precond,NetsolveSSOR)) {
      *solver = ITP_SOLVER_SSORCG;
    } else if (PRECOND_IS(precond,NetsolveReducedSystem)) {
      *solver = ITP_SOLVER_RSCG;
    } else ERR_REPORTi("preconditioner not implemented for CG",precond);
  } else if (ITMETH_IS(itmeth,NetsolveChebychev)) {
    if (PRECOND_IS(precond,NetsolveJacobi)) {
      *solver = ITP_SOLVER_JSI;
    } else if (PRECOND_IS(precond,NetsolveSSOR)) {
      *solver = ITP_SOLVER_SSORSI;
    } else if (PRECOND_IS(precond,NetsolveReducedSystem)) {
      *solver = ITP_SOLVER_RSSI;
    } else ERR_REPORTi("preconditioner not implemented for SI",precond);
  } else ERR_REPORTi("iterative method not implemented in itpack",itmeth);

  return 0;
}

/*
 * Run one ITPACK solve on a 1-based CRS matrix held by a single process.
 *
 * itmeth/precond  select the routine (CG or Chebychev semi-iteration,
 *                 combined with Jacobi, SSOR or reduced-system
 *                 preconditioning)
 * val, idx, ptr   1-based CRS matrix, passed to ITPACK as (A, JA, IA)
 * rhs, sol        right-hand side and solution vector
 * first, last     row range; first must be 1, and the order is last-1
 * rtol, maxit     stopping tolerance (rparm[0]) and iteration limit
 *                 (iparm[0])
 * conv, its       out: converged flag (its < maxit) and the iteration
 *                 count read back from iparm[0]
 * comm, nnzero    unused here
 *
 * The integer and real work arrays are allocated per call and released
 * before returning.
 */
int itpack_iterative_solve
(MPI_Comm comm, NetsolveItmeth itmeth, NetsolvePrecond precond,
 double *val,double *rhs,
 int first,int last, int nnzero,int *idx,int *ptr,
 double *sol, 
 double rtol,int maxit, int *conv,int *its
 )
{
  int size=last-1,lwork,*iwork,iparm[12],info,flops,ierr,solver=-1;
  double *rwork,rparm[12];

  if (first!=1) ERR_REPORTi("Matrix distributed or strange; first=",first);

  /* validate the method/preconditioner pair before allocating anything */
  ierr = itpack_select_solver(itmeth,precond,&solver); ERR_RETURN(ierr);

  /* parameters: start from ITPACK defaults, then override */
  dfault(iparm,rparm);
  iparm[0] = maxit;   iparm[1]=0 /*output*/; iparm[3] = 6/*ftn stdout*/;
  iparm[4] = 1/*nonsymm storage*/; iparm[8]=-2/*no redblack*/; iparm[11]=0;
  rparm[0]=rtol;

  /* work space; the sizes are taken as sufficient for all six routines
     -- NOTE(review): confirm against the ITPACK documentation */
  lwork = 8*size+4*iparm[0];
  ALLOCATE(iwork,3*size,int,"integer work space");
  ALLOCATE(rwork,lwork,double,"real work space");

  switch (solver) {
  case ITP_SOLVER_JCG:
    jcg(&size,ptr,idx,val,rhs,sol,
        iwork,&lwork,rwork,iparm,rparm,
        &info,&flops);
    break;
  case ITP_SOLVER_SSORCG:
    ssorcg(&size,ptr,idx,val,rhs,sol,
           iwork,&lwork,rwork,iparm,rparm,
           &info,&flops);
    break;
  case ITP_SOLVER_RSCG:
    rscg(&size,ptr,idx,val,rhs,sol,
         iwork,&lwork,rwork,iparm,rparm,
         &info,&flops);
    break;
  case ITP_SOLVER_JSI:
    jsi(&size,ptr,idx,val,rhs,sol,
        iwork,&lwork,rwork,iparm,rparm,
        &info,&flops);
    break;
  case ITP_SOLVER_SSORSI:
    ssorsi(&size,ptr,idx,val,rhs,sol,
           iwork,&lwork,rwork,iparm,rparm,
           &info,&flops);
    break;
  case ITP_SOLVER_RSSI:
    rssi(&size,ptr,idx,val,rhs,sol,
         iwork,&lwork,rwork,iparm,rparm,
         &info,&flops);
    break;
  default:
    /* unreachable: itpack_select_solver rejected everything else */
    break;
  }

  /* Release the work space (previously leaked on every call).
     NOTE(review): this assumes ALLOCATE is malloc-based -- confirm in
     sparse_globals.h and switch to its matching release call if not. */
  free(iwork);
  free(rwork);

  /* NOTE(review): the ITPACK error flag `info` is not inspected; only
     the iteration count is used to judge convergence. */
  *its = iparm[0]; *conv = (*its<maxit);

  return 0;
}
  
/*
 * One-shot drivers
 */
int itpack_jcg_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 double rtol,int maxit, iterative_info_block info)
{
  int ierr;

  ierr = itpack_params
    ((itpack_info_block)info,NetsolveCG,NetsolveJacobi); ERR_RETURN(ierr);

  ierr = itpack_iterative_driver
    (MPI_COMM_WORLD,
     values,indices,pointers, first,local_size,
     rhs_vector,sol_vector,
     info); ERR_RETURN(ierr);

  return 0;
}
int itpack_jsi_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 double rtol,int maxit, iterative_info_block info)
{
  int ierr;

  ierr = itpack_params
    ((itpack_info_block)info,NetsolveChebychev,NetsolveJacobi); ERR_RETURN(ierr);

  ierr = itpack_iterative_driver
    (MPI_COMM_WORLD,
     values,indices,pointers, first,local_size,
     rhs_vector,sol_vector,
     info); ERR_RETURN(ierr);

  return 0;
}
int itpack_ssorcg_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 double rtol,int maxit, iterative_info_block info)
{
  int ierr;

  ierr = itpack_params
    ((itpack_info_block)info,NetsolveCG,NetsolveSSOR); ERR_RETURN(ierr);

  ierr = itpack_iterative_driver
    (MPI_COMM_WORLD,
     values,indices,pointers, first,local_size,
     rhs_vector,sol_vector,
     info); ERR_RETURN(ierr);

  return 0;
}
int itpack_ssorsi_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 double rtol,int maxit, iterative_info_block info)
{
  int ierr;

  ierr = itpack_params
    ((itpack_info_block)info,NetsolveChebychev,NetsolveSSOR); ERR_RETURN(ierr);

  ierr = itpack_iterative_driver
    (MPI_COMM_WORLD,
     values,indices,pointers, first,local_size,
     rhs_vector,sol_vector,
     info); ERR_RETURN(ierr);

  return 0;
}
int itpack_rscg_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 double rtol,int maxit, iterative_info_block info)
{
  int ierr;

  ierr = itpack_params
    ((itpack_info_block)info,NetsolveCG,NetsolveReducedSystem); ERR_RETURN(ierr);

  ierr = itpack_iterative_driver
    (MPI_COMM_WORLD,
     values,indices,pointers, first,local_size,
     rhs_vector,sol_vector,
     info); ERR_RETURN(ierr);

  return 0;
}
int itpack_rssi_driver
(MPI_Comm comm,
 double *values,int *indices,int *pointers,
 int first,int local_size,
 double *rhs_vector,double *sol_vector,
 double rtol,int maxit, iterative_info_block info)
{
  int ierr;

  ierr = itpack_params
    ((itpack_info_block)info,NetsolveChebychev,NetsolveReducedSystem); ERR_RETURN(ierr);

  ierr = itpack_iterative_driver
    (MPI_COMM_WORLD,
     values,indices,pointers, first,local_size,
     rhs_vector,sol_vector,
     info); ERR_RETURN(ierr);

  return 0;
}
