#ifndef SOLVE_H
#define SOLVE_H

#include <iostream.h>
#include <fstream.h>

#include "Array.h"
#include "Permutation.h"
#include "EliminationForest.h"
#include "SparseFactors.h"
#include "Vector.h"
#include "Error.h"
#include "Utilities.h"
#include "Kernels.h"

#undef __CLASS__
#define __CLASS__ "Solve"
/*
 * Solve<T>: applies a previously computed sparse factorization -- a
 * permutation (p), an elimination forest (f) and the factor entries (l) --
 * to a right-hand side b, writing the solution into x.  See run().
 *
 * The object owns none of its operands: each setter only records the
 * address of its argument, so every operand must stay alive (and must not
 * be a temporary) until run() has returned.
 */
template<class T>
class Solve
{
  private:
    /* Borrowed operands; all start out null (see the constructor). */
    const Permutation* p_;
    const EliminationForest* f_;
    const SparseFactors<T>* l_;
    Vector<T>* x_;
    const Vector<T>* b_;

    /* Copy construction and assignment are declared private and no
       definition appears in this header, i.e. copying is disallowed. */
    Solve(const Solve<T>&);
    Solve<T>& operator=(const Solve<T>&);

    /* Solves the 2x2 diagonal block that begins at 'column', in place in x. */
    void solve2x2(int column, Vector<T>& x);

  public:
    Solve();
    virtual ~Solve();

    /* Operand setters: store &argument, no copy is taken. */
    void setP(const Permutation& p) { p_ = &p; }
    void setF(const EliminationForest& f) { f_ = &f; }
    void setL(const SparseFactors<T>& l) { l_ = &l; }
    void setX(Vector<T>& x) { x_ = &x; }
    void setB(const Vector<T>& b) { b_ = &b; }

    /* Computes x from b using the operands supplied above. */
    void run(void);
};

#undef __FUNC__
#define __FUNC__ "Solve"
/*
 * Default constructor.  Every operand pointer starts out null; run()
 * reports InvalidInput for any operand whose setter was never called.
 */
template<class T>
Solve<T>::Solve()
  : p_(0), f_(0), l_(0), x_(0), b_(0)
{
  BEGIN_FUNCTION();
  /* Nothing else to initialize. */
  END_FUNCTION();
}

#undef __FUNC__
#define __FUNC__ "~Solve"
/*
 * Destructor.  The operand pointers only refer to objects owned by the
 * caller (the setters store addresses), so there is nothing to release.
 */
template<class T>
Solve<T>::~Solve()
{
  BEGIN_FUNCTION();
  END_FUNCTION();
}

#undef __FUNC__
#define __FUNC__ "run"
/*
 * Computes x from b using the permutation p_ and the factor l_:
 *
 *   1. t1[k] = b[newToOld[k]]                       (permute the right side)
 *   2. forward substitution with the lower factor stored column-wise in l_,
 *      treated as unit lower triangular: the head of each column is skipped
 *      (two entries when its pivot type is 2 -- the a11/a21 entries that
 *      solve2x2 reads -- and one entry otherwise);
 *   3. block-diagonal solve: pivot type 1 divides by the first stored entry
 *      of the column, any other pivot type hands a 2x2 block starting at
 *      that column to solve2x2;
 *   4. backward substitution with the transpose of the same factor
 *      (conjugate transpose when HERMITIAN is defined, matching the
 *      Hermitian 2x2 blocks used by solve2x2);
 *   5. x[k] = t1[oldToNew[k]]                       (undo the permutation).
 *
 * Errors: InvalidInput if an operand was never set, if the operand orders
 * disagree, or if a 2x2 pivot would start on the last column.
 *
 * The column-at-a-time kernels are used by default.  Defining
 * SOLVE_FRONTAL_KERNELS before including this header selects the
 * front-by-front variant built on Axpy/Dot from Kernels.h.
 * NOTE(review): that variant has not been validated against the default
 * path (its HERMITIAN behaviour depends on how Dot is defined) -- verify
 * before enabling.
 */
template<class T>
void Solve<T>::run(void)
{
  BEGIN_FUNCTION();

  /* Every operand must have been supplied through its setter. */
  if (p_ == 0)
    SET_ERROR1(InvalidInput);

  if (f_ == 0)
    SET_ERROR1(InvalidInput);

  if (l_ == 0)
    SET_ERROR1(InvalidInput);

  if (x_ == 0)
    SET_ERROR1(InvalidInput);

  if (b_ == 0)
    SET_ERROR1(InvalidInput);

  /* All operands must describe systems of the same order. */
  if (l_->getOrder() != p_->getOrder())
    SET_ERROR1(InvalidInput);

  if (l_->getOrder() != f_->getOrder())
    SET_ERROR1(InvalidInput);

  if (l_->getOrder() != x_->getOrder())
    SET_ERROR1(InvalidInput);

  if (l_->getOrder() != b_->getOrder())
    SET_ERROR1(InvalidInput);

  int order = l_->getOrder();

  /* Work vector(s) in the permuted ordering. */
#ifndef SOLVE_FRONTAL_KERNELS
  Vector<T> t1(order);
#else
  Vector<T> t1(order), t2(order);
#endif
  CHECK_ERROR1();

  const int* pOldToNewItem = p_->getOldToNew()->getItem();
  const int* pNewToOldItem = p_->getNewToOld()->getItem();
  const int* lColumnPointerItem = l_->getColumnPointer()->getItem();
  const int* lRowIndexItem = l_->getRowIndex()->getItem();
  const T* lEntryItem = l_->getEntry()->getItem();
  const int* lPivotTypeItem = l_->getPivotType()->getItem();
  T* xEntryItem = x_->getEntry()->getItem();
  const T* bEntryItem = b_->getEntry()->getItem();
  T* t1EntryItem = t1.getEntry()->getItem();
#ifdef SOLVE_FRONTAL_KERNELS
  const int* fFrontPointerItem = f_->getFrontPointer()->getItem();
  T* t2EntryItem = t2.getEntry()->getItem();
#endif

  /* Reorder the right hand side. */
  {for (int k = 0; k < order; k++)
    t1EntryItem[k] = bEntryItem[pNewToOldItem[k]];}

  /* Solve the lower triangular system (forward). */
#ifndef SOLVE_FRONTAL_KERNELS
  {for (int j = 0; j < order; j++)
    for (int p = lColumnPointerItem[j] + ((lPivotTypeItem[j] != 2) ? 1 : 2);
         p < lColumnPointerItem[j + 1];
         p++)
      t1EntryItem[lRowIndexItem[p]] -= lEntryItem[p] * t1EntryItem[j];}
#else
  /* Front by front: gather the rows of the front's first column into t2,
     eliminate with Axpy, then scatter back into t1. */
  {for (int u = f_->getFirstPostorderNode();
        u != -1;
        u = f_->getNextPostorderNode(u))
  {
    {for (int lp = lColumnPointerItem[fFrontPointerItem[u]], tp = 0;
          lp < lColumnPointerItem[fFrontPointerItem[u] + 1];
          lp++, tp++)
      t2EntryItem[tp] = t1EntryItem[lRowIndexItem[lp]];}
    for (int j = fFrontPointerItem[u],
             offset = 0,
             length = lColumnPointerItem[fFrontPointerItem[u] + 1] -
                      lColumnPointerItem[fFrontPointerItem[u]];
         j < fFrontPointerItem[u + 1];
         j++, offset++, length--)
    {
      int delta = ((lPivotTypeItem[j] != 2) ? 1 : 2),
          adjustedOffset = offset + delta,
          adjustedLength = length - delta;
      T a = -t2EntryItem[offset];
      Axpy(&adjustedLength, &a,
           &lEntryItem[lColumnPointerItem[j] + delta],
           &t2EntryItem[adjustedOffset]);
    }
    {for (int lp = lColumnPointerItem[fFrontPointerItem[u]], tp = 0;
          lp < lColumnPointerItem[fFrontPointerItem[u] + 1];
          lp++, tp++)
      t1EntryItem[lRowIndexItem[lp]] = t2EntryItem[tp];}
  }}
#endif

  /* Solve the diagonal system. */
  {for (int k = 0; k < order;)
    if (lPivotTypeItem[k] == 1)
    {
      t1EntryItem[k] /= lEntryItem[lColumnPointerItem[k]];
      k++;
    }
    else
    {
      /* A 2x2 block needs a second column.  Without this check solve2x2
         would read past the end of the factor entries and write t1[order],
         one element past the end of t1. */
      if (k + 1 >= order)
      {
        SET_ERROR1(InvalidInput);
        break;
      }
      solve2x2(k, t1);
      k += 2;
    }}

  /* Solve the upper triangular system (backward). */
#ifndef SOLVE_FRONTAL_KERNELS
  {for (int i = order - 1; i >= 0; i--)
  {
    for (int p = lColumnPointerItem[i] + ((lPivotTypeItem[i] != 2) ? 1 : 2);
         p < lColumnPointerItem[i + 1];
         p++)
    {
#ifndef HERMITIAN
      t1EntryItem[i] -= lEntryItem[p] * t1EntryItem[lRowIndexItem[p]];
#else
      /* The factorization is L D L^H here (solve2x2 conjugates the 2x2
         blocks), so the backward pass must use the conjugate of L. */
      t1EntryItem[i] -= Conj(lEntryItem[p]) * t1EntryItem[lRowIndexItem[p]];
#endif
    }
  }}
#else
  /* Front by front, in preorder: gather, reduce with Dot, scatter. */
  {for (int u = f_->getFirstPreorderNode();
        u != -1;
        u = f_->getNextPreorderNode(u))
  {
    {for (int lp = lColumnPointerItem[fFrontPointerItem[u]], tp = 0;
          lp < lColumnPointerItem[fFrontPointerItem[u] + 1];
          lp++, tp++)
      t2EntryItem[tp] = t1EntryItem[lRowIndexItem[lp]];}
    for (int i = fFrontPointerItem[u + 1] - 1,
             offset = fFrontPointerItem[u + 1] -
                      fFrontPointerItem[u] - 1,
             length = lColumnPointerItem[fFrontPointerItem[u] + 1] -
                      lColumnPointerItem[fFrontPointerItem[u]] - offset;
         i >= fFrontPointerItem[u];
         i--, offset--, length++)
    {
      int delta = ((lPivotTypeItem[i] != 2) ? 1 : 2),
          adjustedOffset = offset + delta,
          adjustedLength = length - delta;
      T a = 0;
      Dot(&adjustedLength, &a,
          &lEntryItem[lColumnPointerItem[i] + delta],
          &t2EntryItem[adjustedOffset]);
      t2EntryItem[offset] -= a;
    }
    {for (int lp = lColumnPointerItem[fFrontPointerItem[u]], tp = 0;
          lp < lColumnPointerItem[fFrontPointerItem[u] + 1];
          lp++, tp++)
      t1EntryItem[lRowIndexItem[lp]] = t2EntryItem[tp];}
  }}
#endif

  /* Reorder the solution. */
  {for (int k = 0; k < order; k++)
    xEntryItem[k] = t1EntryItem[pOldToNewItem[k]];}

  END_FUNCTION();
}

#undef __FUNC__
#define __FUNC__ "solve2x2"
/*
 * Solves the 2x2 diagonal block whose first column is 'column', replacing
 * x[column] and x[column + 1] with the solution.
 *
 * The block is taken from the factor l_:
 *   d11 = first stored entry of column 'column'
 *   d21 = second stored entry of column 'column'
 *   d22 = first stored entry of column 'column + 1'
 *   d12 = d21, or Conj(d21) when HERMITIAN is defined.
 * The system is solved by a 2x2 LU factorization with partial (row)
 * pivoting.  No check is made for a singular block.
 */
template<class T>
void Solve<T>::solve2x2(int column, Vector<T>& x)
{
  BEGIN_FUNCTION();

  const int* colPtr = l_->getColumnPointer()->getItem();
  const T* entry = l_->getEntry()->getItem();
  T* rhs = x.getEntry()->getItem();

  const int head1 = colPtr[column];
  const int head2 = colPtr[column + 1];

  /* Gather the block and its right-hand side. */
  T d11 = entry[head1];
  T d21 = entry[head1 + 1];
#ifdef HERMITIAN
  T d12 = Conj(d21);
#else
  T d12 = d21;
#endif
  T d22 = entry[head2];
  T r1 = rhs[column];
  T r2 = rhs[column + 1];

  /* Exchange the two rows when the sub-diagonal entry is the larger one. */
  if (Abs(d11) < Abs(d21))
  {
    Swap(d11, d21);
    Swap(d12, d22);
    Swap(r1, r2);
  }

  /* Eliminate below the pivot: multiplier and the reduced second pivot. */
  T m = d21 / d11;
  T pivot2 = d22 - m * d12;

  /* Forward substitution (row 1 is unchanged), then back substitution. */
  T z2 = r2 - m * r1;
  T sol2 = z2 / pivot2;
  T sol1 = (r1 - d12 * sol2) / d11;

  /* Store the solution in place. */
  rhs[column] = sol1;
  rhs[column + 1] = sol2;

  END_FUNCTION();
}

#endif
