#include "parpre_pc.h"
#include "./ml_impl.h"
#include "./ml_head.h"

#undef __FUNC__
#define __FUNC__ "SolveSchur"
/*
   SolveSchur - solve with the 2-block (Schur) operator: rhs -> sol.

   When a next level exists the solve recurses through SolveMultiLevel on
   that level.  Otherwise (deepest level) a22_solve is applied directly,
   starting from a zero guess.  Having neither is an error, since empty
   deepest levels are expected to be rejected earlier.

   Returns 0 on success, a PETSc error code otherwise.
*/
static int SolveSchur(MC_OneLevel_struct *next_level,
		       SLES a22_solve,Vec rhs,Vec sol,int trace)
{
  Scalar zero = 0.;
  int    ierr,its;

  PetscFunctionBegin;
  if (!next_level && !a22_solve) {
    SETERRQ(1,0,"Empty deepest level should have been caught");
  }

  if (next_level) {
    /* recursive case: hand the system to the next level */
    ierr = SolveMultiLevel(next_level,rhs,sol,trace); CHKERRQ(ierr);
    PetscFunctionReturn(0);
  }

  /* deepest level: direct solve from an explicitly zeroed guess.
     REMOVE the VecSet after Barry patches guess_zero */
  ierr = VecSet(&zero,sol); CHKERRQ(ierr);
  ierr = SLESSolve(a22_solve,rhs,sol,&its); CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "Solve2Block"
/*
   Solve2Block - apply a polynomial in the Schur solve to rhs, into sol.

   Let S denote SolveSchur(next_level,a22_solve,.) and A the matrix a22.
   For degree == 1 this is simply sol = S rhs.  For degree > 1 it runs a
   Horner-style sweep:

       t = rhs                      (times coffs[degree] if degree is even)
       for k = degree-1 down to 1:
           sol = S t
           t   = A sol + coffs[k] rhs
       sol = S t

   Parameters:
     coffs  - coefficients; read at indices 1..degree-1, plus index degree
              when degree is even (index 0 is never read)
     t      - work vector, overwritten
     sol    - result, overwritten
   NOTE(review): for odd degree the leading coefficient is never applied
   (effectively 1) -- confirm this matches how coffs is constructed.

   Returns 0 on success; a PETSc error if degree <= 0.
*/
static int Solve2Block(MC_OneLevel_struct *next_level,
		       SLES a22_solve,Mat a22,int degree,Scalar *coffs,
		       Vec rhs,Vec sol,Vec t,int trace)
{
  int ierr,k;

  PetscFunctionBegin;
  if (degree <= 0) {
    SETERRQ(1,0,"Non-positive degree encountered\n");
  }

  if (degree == 1) {
    /* no polynomial: a single Schur solve */
    ierr = SolveSchur(next_level,a22_solve,rhs,sol,trace); CHKERRQ(ierr);
    PetscFunctionReturn(0);
  }

  ierr = VecCopy(rhs,t); CHKERRQ(ierr);
  if ((degree % 2) == 0) {
    ierr = VecScale(coffs+degree,t); CHKERRQ(ierr);
  }

  k = degree;
  while (--k > 0) {
    ierr = SolveSchur(next_level,a22_solve,t,sol,trace); CHKERRQ(ierr);
    ierr = MatMult(a22,sol,t); CHKERRQ(ierr);
    /* t <- coffs[k]*rhs + t */
    ierr = VecAXPY(coffs+k,rhs,t); CHKERRQ(ierr);
  }

  ierr = SolveSchur(next_level,a22_solve,t,sol,trace); CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNC__
#define __FUNC__ "SolveThisLevel"
/*
   SolveThisLevel - one two-block solve on this_level: x -> y.

   1. Gather the "clr" (1-block) part of x into g1 and solve with
      a11_solve into h1 from a zero guess.  Under AMLSolveILU, h1 is
      inserted into y right away.
   2. Gather the "rest" (2-block) part of x into v2.  If the level has a
      transfer operator, subtract b12^T g1 from h2.
   3. Solve the 2-block (see Solve2Block) with rhs h2 into g2, using k2
      as work space.  Then insert u2 into the rest part of y.
   4. If transferring, form h1 = -b12 g2.  Under ILU it is added into the
      clr part of y; otherwise it is inserted there.

   NOTE(review): step 2 fills v2 but step 3 reads h2, and step 3 writes g2
   but scatters u2.  This is only consistent if v2/h2 and u2/g2 are the
   same Vec objects (set up elsewhere) -- confirm against the level setup.

   Returns 0 on success, a PETSc error code otherwise.  Errors from every
   scatter are now propagated; they used to be ignored.
*/
static int SolveThisLevel(MC_OneLevel_struct *this_level,Vec x,Vec y,int trace)
{
  Vec g1 = this_level->g1,h1 = this_level->h1,
    g2 = this_level->g2,h2 = this_level->h2,k2 = this_level->k2;
  VecScatter put_clr = this_level->put_clr,get_clr = this_level->get_clr,
    put_rest = this_level->put_rest,get_rest = this_level->get_rest;
  Scalar zero = 0., mone = -1.0;
  int its,ierr;

  PetscFunctionBegin;
  /* test the pointers themselves: casting them to int (as before) can
     truncate a valid pointer to 0 on 64-bit platforms */
  if (!this_level->a22_solve && !this_level->next_level)
    SETERRQ(1,0,"This cannot happen: end in 1-block");

  /* solve (1,1) block into h1; if ILU write back into y */
  ierr = VecScatterFor(x,g1,INSERT_VALUES,get_clr); CHKERRQ(ierr);
  ierr = VecSet(&zero,h1); CHKERRQ(ierr);
  ierr = SLESSolve(this_level->a11_solve,g1,h1,&its); CHKERRQ(ierr);
  if (this_level->solve_scheme==AMLSolveILU) {
    ierr = VecScatterFor(h1,y,INSERT_VALUES,put_clr); CHKERRQ(ierr);
  }

  /* make next level rhs by extracting and subtracting A21A11invX1 */
  ierr = VecScatterFor(x,this_level->v2,INSERT_VALUES,get_rest); CHKERRQ(ierr);
  if (this_level->transfer) {
    /* b12^T is applied to g1; with separate A11 and A21 operators this
       would take h1 instead */
    ierr = MatMultTranspose(this_level->b12,g1,g2); CHKERRQ(ierr);
    ierr = VecAXPY(&mone,g2,h2); CHKERRQ(ierr);
  }

  /* solve next level and write back to global vector */
  ierr = Solve2Block(this_level->next_level,
		     this_level->a22_solve,this_level->g22,
		     this_level->degree,this_level->coffs,
		     h2,g2,k2,trace);
  CHKERRQ(ierr);
  ierr = VecScatterFor(this_level->u2,y,INSERT_VALUES,put_rest); CHKERRQ(ierr);

  /* transfer to fine grid of previous level */
  if (this_level->transfer) {
    /* h1 = -b12 g2.  With separate A11 and A12 this would instead be
       g1 = g12 g2 followed by an a11_solve of g1 (zero guess) into h1. */
    ierr = MatMult(this_level->b12,g2,h1); CHKERRQ(ierr);
    ierr = VecScale(&mone,h1); CHKERRQ(ierr);
    if (this_level->solve_scheme==AMLSolveILU) {
      /* ILU already stored the 1-block solve in y: add the correction */
      ierr = VecScatterFor(h1,y,ADD_VALUES,put_clr); CHKERRQ(ierr);
    } else {
      ierr = VecScatterFor(h1,y,INSERT_VALUES,put_clr); CHKERRQ(ierr);
    }
  }

  PetscFunctionReturn(0);
}

/* the main recursive solve routine */
#undef __FUNC__
#define __FUNC__ "SolveMultiLevel"
/*
   SolveMultiLevel - apply the multilevel solve of this_level to x, into y.

   This is the recursive entry point.  The 2-block solve inside
   SolveThisLevel reaches the next level through SolveSchur, which calls
   back into this routine.

   Without a pre_smooth SLES this is just SolveThisLevel(x -> y).
   With one (post_smooth is then mandatory) it performs
       y  = PreSmooth(x)                  (zero initial guess)
       y += SolveThisLevel(x - mat*y)
       y += PostSmooth(x - mat*y)         (zero initial guess)
   using this_level->u and this_level->v as work vectors.

   Returns 0 on success, a PETSc error code otherwise.
*/
int SolveMultiLevel(MC_OneLevel_struct *this_level,Vec x,Vec y,int trace)
{
  Scalar mone=-1., one=1., zero=0.;
  int its,ierr;

  PetscFunctionBegin;
  /* test the SLES handle directly; casting a pointer to int can
     truncate it to 0 on 64-bit platforms */
  if (this_level->pre_smooth) {
    if (!(this_level->post_smooth)) SETERRQ(1,0,"AMG solve needs post smooth");

    /* pre-smooth: y = S_pre x.
       REMOVE THIS VecSet after Barry patches the guess_zero bug */
    ierr = VecSet(&zero,y); CHKERRQ(ierr);
    ierr = SLESSolve(this_level->pre_smooth,x,y,&its); CHKERRQ(ierr);

    /* residual u = x - mat*y  (VecAYPX computes u = x + alpha*u) */
    ierr = MatMult(this_level->mat,y,this_level->u); CHKERRQ(ierr);
    ierr = VecAYPX(&mone,x,this_level->u); CHKERRQ(ierr);

    /* level correction v from the residual, then y += v */
    ierr = SolveThisLevel(this_level,this_level->u,this_level->v,trace);
    CHKERRQ(ierr);
    ierr = VecAXPY(&one,this_level->v,y); CHKERRQ(ierr);

    /* updated residual u = x - mat*y */
    ierr = MatMult(this_level->mat,y,this_level->u); CHKERRQ(ierr);
    ierr = VecAYPX(&mone,x,this_level->u); CHKERRQ(ierr);

    /* post-smooth correction v = S_post u from a zero guess, then y += v */
    ierr = VecSet(&zero,this_level->v); CHKERRQ(ierr);
    ierr = SLESSolve
      (this_level->post_smooth,this_level->u,this_level->v,&its);
    CHKERRQ(ierr);
    ierr = VecAXPY(&one,this_level->v,y); CHKERRQ(ierr);
  } else {
    ierr = SolveThisLevel(this_level,x,y,trace); CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}
