@PROBLEM petsc_user_func
@INCLUDE <math.h>
@INCLUDE <stdio.h>
@INCLUDE <stdlib.h>
@INCLUDE <string.h>
@FUNCTION netsolve_iterative_solve
@DASHI $(MPI_INCLUDE_DIR)
@DASHI $(NETSOLVE_ROOT)/src/SampleNumericalSoftware/SparseSolvers/sparse
@DASHI $(NETSOLVE_ROOT)/src/SampleNumericalSoftware/SparseSolvers/sparse/aux
@DASHI $(NETSOLVE_ROOT)/src/SampleNumericalSoftware/SparseSolvers/sparse/driver
@DASHI $(PETSC_DIR)
@DASHI $(PETSC_DIR)/bmake/linux
@DASHI $(PETSC_DIR)/include
@INCLUDE "petsc.h"
@INCLUDE "mpi.h"
@INCLUDE "parallel_auxs.h"
@INCLUDE "matrix_auxs.h"
@INCLUDE "petsc/netsolve_petsc.h"
@INCLUDE "petsc/petsc_auxs.h"
@LIB -lm
@LIB -L$(NETSOLVE_ROOT)/lib/$(NETSOLVE_ARCH)
@LIB -lnetsolve_petsc
@LIB -lnetsolve_iterative_auxs
@LIB -lnetsolve_aux
@LIB -lnetsolve_aux_distr
@LIB -lnetsolve_timer
@LIB -lnetsolve_tester
@LIB -L$(PETSC_LIB_DIR)
@LIB -lpetscts
@LIB -lpetscsnes
@LIB -lpetscsles
@LIB -lpetscdm
@LIB -lpetscmat
@LIB -lpetscvec
@LIB -lpetsc
@LIB $(LAPACK_LIB_LINK)
@LIB $(BLAS_LIB_LINK)
@LIB -L$(MPI_DIR)/lib
@LIB -lmpich
@LIB -ldl
@LIB -lc
@LIB -lm

@FUNCTION netsolve_iterative_solve
@LANGUAGE C
@MAJOR ROW
@PATH /PETSc-Aztec/
@COMPLEXITY 3,2
@CUSTOMIZED ITER_SOLVE
@DESCRIPTION 
netsolve iterative solve using petsc and aztec
@PARALLEL MPI

@INPUT 5
@OBJECT SPARSEMATRIX D sm
the sparse matrix
@OBJECT VECTOR D rhs_vector
the right-hand-side vector
@OBJECT SCALAR D rtol
relative convergence tolerance for the iterative solver
@OBJECT SCALAR I maxit
maximum number of iterations
@OBJECT UPF test_func
test function

@OUTPUT 2
@OBJECT VECTOR D sol_vector
solution vector
@OBJECT SCALAR I iterations
number of iterations taken to reach convergence

@CALLINGSEQUENCE
@ARG mI0, nI0, mI1
@ARG fI0
@ARG I0
@ARG iI0
@ARG pI0
@ARG I1
@ARG I2
@ARG I3
@ARG I4
@ARG O0
@ARG O1

@CODE
/*
 * NetSolve code fragment: distributed iterative solve of a sparse linear
 * system using the PETSc driver.  Each MPI rank extracts its slice of the
 * CRS matrix and right-hand side, runs the solver, and rank 0 assembles
 * the global solution vector plus the iteration count.
 *
 * Template placeholders (substituted by the NetSolve code generator):
 *   @I0@  - CRS value array          @iI0@ - CRS column-index array
 *   @pI0@ - CRS row-pointer array    @mI0@ - matrix row count
 *   @fI0@ - number of nonzeros       @mI1@ - rhs length (global size)
 *   @I1@  - right-hand-side vector   @I2@  - relative tolerance (rtol)
 *   @I3@  - maximum iterations       @O0@  - solution vector (output)
 *   @O1@  - iteration count (output) @mO0@ - solution length (output)
 */
extern int netsolve_iterative_solve();
extern int petsc_params();
/* User-provided convergence test callback handed to petsc_params(). */
extern int upf4(double anorm,double rnorm,double r0norm,double bnorm,
                     double xnorm,double epsr,double epsa,double epsd,
                     int it,int maxit,
                     int *outcome);

MPI_Comm comm;
int i,j, index;
int ierr, conv, *convs=NULL, flag;
int Argc = 1;                    /* fake argc/argv for the PETSc driver */
char **Args;
int local_size;                  /* number of rows owned by this rank */
int first;                       /* global index of the first local row */
int mytid, ntids;                /* MPI rank and communicator size */
int its, *iterations=NULL;       /* per-rank iteration counts (root only) */
int total_size, last, nnzero, this_size;
int *ptr, *idx;                  /* local CRS row pointers / column indices */
double *value_arr, *rhs, *sol;   /* local CRS values, rhs slice, solution slice */
double** result = NULL;          /* gathered per-rank solution pieces (root only) */
iterative_info_block info=0;
MPI_Status* status = NULL;       /* recv statuses, indexed by sender rank (root only) */

  /* Fake command line for PETSc initialization inside the driver. */
  Args = (char**)calloc(1,sizeof(char*));
  Args[0] = strdup("dummy");
  comm = MPI_COMM_WORLD; 
  MPI_Comm_size(comm,&ntids);
  MPI_Comm_rank(comm,&mytid);

  /* Record the solver parameters in the shared info block. */
  ierr = petsc_allocate_info_block(comm,&info);

  ierr = iterative_set_params(info,*@I2@,*@I3@);

  ierr = petsc_params(info,1.e-12,1.e+6,&upf4);

  /* The drivers expect 0-based CRS indexing; convert if input is 1-based. */
  if(@pI0@[0] == 1)
    ierr = crs_mat_tobase0(@I0@, @iI0@, @pI0@, *@mI0@);

  first = 0;
  local_size = (*@mI0@ > 1) ? *@mI0@ : 1;
  /* Provisional output buffers; rank 0 replaces them with full-size
     buffers once the global problem size is known. */
  @O0@ = (double*)malloc(local_size*sizeof(double));
  @O1@ = (int*)malloc(sizeof(int));

  /* Root knows the global size; broadcast it and split the rows so that
     this rank owns the half-open range [first,last). */
  if(mytid == 0){
    total_size = *@mI1@;
  }
  MPI_Bcast(&total_size,1,MPI_INT,0,comm);
  ierr = divide(comm,total_size,&first,&last);
  local_size = last - first;
  ierr = crs_nnzeros(@pI0@,total_size,0,&nnzero);

  /* Scratch storage for the local slice of the matrix, rhs and solution. */
  ptr = (int *) malloc((local_size+1)*sizeof(int));
  idx = (int *) malloc((nnzero+1)*sizeof(int));
  value_arr = (double *) malloc((nnzero+1)*sizeof(double));
  rhs = (double *) malloc((local_size)*sizeof(double));
  sol = (double *) malloc((local_size)*sizeof(double));

  /* Copy this rank's nonzeros; the (@pI0@[first]-@pI0@[0]) offset rebases
     the global nonzero range onto index 0 of the local arrays. */
  for(i = (@pI0@[first]-@pI0@[0]);i<(@pI0@[last]-@pI0@[0]);i++){
    value_arr[i-(@pI0@[first]-@pI0@[0])] = @I0@[i];
    idx[i-(@pI0@[first]-@pI0@[0])] = @iI0@[i];
  }

  /* Local row pointers are rebased so that ptr[0] == 0. */
  for(i=first;i< last;i++){
    ptr[i-first] = @pI0@[i] - @pI0@[first];
    rhs[i-first] = @I1@[i];
  }
  ptr[last-first] = @pI0@[last] - @pI0@[first];

  /* Run the PETSc iterative solver on the local slice. */
  ierr = petsc_iterative_driver
    (Argc,Args, MPI_COMM_WORLD, first,local_size,
     value_arr,idx,ptr,
     rhs,sol,
     info);

  ierr = iterative_get_return_params(info,&conv,&its);

  if(mytid == 0){
    /* Replace the provisional output buffers with full-size ones
       (the provisional ones would otherwise leak). */
    free(@O0@);
    free(@O1@);
    @O0@ = (double*)malloc(total_size*sizeof(double));
    @O1@ = (int*)malloc(sizeof(int));
    result = (double**)malloc(sizeof(double*)*ntids);
    status = (MPI_Status*)malloc(sizeof(MPI_Status)*ntids);
    for(i=0;i<ntids;i++)
      result[i] = (double*)malloc(sizeof(double)*total_size);
    iterations = (int*)malloc(sizeof(int)*ntids);
    convs = (int*)malloc(sizeof(int)*ntids);
    *@O1@ = 0;
    *@mO0@ = total_size;

    /* Receive the solution pieces from all the other processes. */
    for(i=1;i<ntids;i++)
      MPI_Recv(result[i], total_size, MPI_DOUBLE, i, 7, comm, status+i);

    /* Rank 0's own piece comes first ... */
    for(index=0;index<local_size;index++){
      @O0@[index] = sol[index];
    }

    /* ... followed by each remote piece, concatenated in rank order;
       the actual piece length comes from the message status. */
    for(i=1;i<ntids;i++){
      MPI_Get_count(status+i, MPI_DOUBLE, &this_size);
      for(j=0;j<this_size;j++){
        @O0@[index] = result[i][j];
        index++;
      }
   }
  }
  else{
    MPI_Send(sol, local_size, MPI_DOUBLE, 0, 7, comm);
  }
  /* Collect per-rank iteration counts and convergence flags on root
     (recv buffers are ignored, and may be NULL, on non-root ranks). */
  MPI_Gather(&its,1, MPI_INT, iterations, 1, MPI_INT, 0, comm);
  MPI_Gather(&conv,1, MPI_INT, convs, 1, MPI_INT, 0, comm);

  if(mytid == 0){
    /* Declare success if ANY rank's convergence test succeeded. */
    flag = 0;
    for(i=0;i<ntids;i++){
      if(convs[i]) flag = 1;
    }
    if (flag) {
      *@O1@ = iterations[0];
      printf("Convergence in %d iterations\\n",*@O1@);
    } else {
      printf("No convergence in the specified number of iterations.\\n");
    }
  }

  /* Release all scratch storage; the outputs @O0@/@O1@ are kept for the
     NetSolve wrapper. free(NULL) is a no-op for the non-root pointers. */
  free(ptr);
  free(idx);
  free(value_arr);
  free(rhs);
  free(sol);
  free(iterations);
  free(convs);
  if(mytid == 0){
    for(i=0;i<ntids;i++)
      free(result[i]);
    free(result);
    free(status);
  }
  free(Args[0]);
  free(Args);

/* NOTE(review): info comes from petsc_allocate_info_block(); if the library
   provides a matching deallocator, it should be used instead of free(). */
free(info);
@END_CODE
