// Fungimol - an extensible system for designing atomic-scale objects.
// Copyright (C) 2000 Tim Freeman
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
// 
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Library General Public License for more details.
// 
// You should have received a copy of the GNU Library General Public
// License along with this library in the file COPYING.txt; if not,
// write to the Free Software Foundation, Inc., 59 Temple Place -
// Suite 330, Boston, MA 02111-1307, USA
//
// The author can be reached by email at tim@infoscreen.com, or by
// paper mail at:
//
// Tim Freeman
// 655 S. FairOaks Ave., Apt B-316
// Sunnyvale, CA 94086
//

#include "RungaKutta.h"

#ifndef __Float_h__
#include "Float.h"
#endif

#ifndef __Evaluator_h__
#include "Evaluator.h"
#endif

// Construct an integrator in the stopped state.  No scratch arrays are
// allocated and no Evaluator is attached yet; call setEvaluator() and
// then start() before observe().
//
// Every pointer member is nulled, not just m_k1.  In particular m_f must
// start out null so that the assert (m_f) in start() really detects a
// missing setEvaluator() call instead of reading an indeterminate value.
RungaKutta::RungaKutta () {
  m_f = 0;
  m_k1 = 0;
  m_k2 = 0;
  m_k3 = 0;
  m_k4 = 0;
  m_tmp = 0;
  m_startedSize = 0;
  m_time = 0;
  m_step = 0;
}

// Destroy the integrator, releasing the scratch arrays if a run is still
// in progress.  stop() insists on reporting the final time and step size
// through out-parameters; those values are simply discarded here.
RungaKutta::~RungaKutta () {
  if (m_k1 != 0) {
    Float discardedTime, discardedStep;
    stop (discardedTime, discardedStep);
  }
}

#ifndef __myassert_h__
#include "myassert.h"
#endif

// Begin a run at time initialTime with step size initialStep.
// Requires that the integrator is not already started (call stop() first)
// and that an Evaluator has been attached with setEvaluator().
// Allocates one scratch array per RK stage (k1..k4) plus a temporary state
// buffer, each sized to the evaluator's current state vector.  That size
// is remembered in m_startedSize so observe() can detect later changes.
void RungaKutta::start (Float initialTime,
			Float initialStep) {
  assert (m_k1 == 0);
  m_time = initialTime;
  m_step = initialStep;
  assert (m_f != 0);
  m_startedSize = m_f->stateVec().size();
  // Stage derivatives.
  m_k1 = NEW (Float [m_startedSize]);
  m_k2 = NEW (Float [m_startedSize]);
  m_k3 = NEW (Float [m_startedSize]);
  m_k4 = NEW (Float [m_startedSize]);
  // Trial state passed to the evaluator for stages 2-4.
  m_tmp = NEW (Float [m_startedSize]);
}

// Attach the Evaluator whose state vector is integrated.  The pointer is
// stored as-is; this object does not delete it.
void RungaKutta::setEvaluator (Evaluator *f) {
  this->m_f = f;
}

// Return the Evaluator most recently passed to setEvaluator().
Evaluator *RungaKutta::getEvaluator () {
  return this->m_f;
}

// Report whether start() has been called without a matching stop().
// m_k1 doubles as the "running" flag; when it is set, the other scratch
// arrays allocated alongside it in start() must be set too.
bool RungaKutta::isStarted () {
  if (!m_k1) {
    return false;
  }
  assert (m_k2 && m_k3 && m_k4 && m_tmp);
  return true;
}

// End the current run.  The current time and step size are reported
// through the out-parameters, the scratch arrays allocated by start() are
// freed, and the integrator returns to the stopped state so start() may
// be called again.  Requires isStarted().
void RungaKutta::stop (Float &time, Float &step) {
  time = m_time;
  step = m_step;
  assert (isStarted());
  delete [] m_k1;
  delete [] m_k2;
  delete [] m_k3;
  delete [] m_k4;
  delete [] m_tmp;
  // Null every freed pointer, not only m_k1 (the "started" flag), so no
  // dangling pointers survive into the stopped state.
  m_k1 = 0;
  m_k2 = 0;
  m_k3 = 0;
  m_k4 = 0;
  m_tmp = 0;
}

// Advance the integration by one classical fourth-order Runge-Kutta step
// of size m_step.  The evaluator's state vector is updated in place, and
// the new current time (m_time + m_step) is returned.
//
// Requires isStarted().  If the evaluator's state vector has changed size
// since the scratch arrays were allocated, they are reallocated via
// stop()/start(), which preserves m_time and m_step, before the step is
// attempted.
Float RungaKutta::observe () {
  assert (isStarted ());
  // If an object is deleted in the middle of a timestep, then its
  // derivative will be zero.  This discontinuity doesn't faze
  // RungaKutta because it isn't trying to determine whether it is
  // being accurate or not.  However, it probably could cause an
  // adaptive algorithm to have fits.  So whenever eval() returns true
  // ("retry"), the partial step is abandoned (state has not been written
  // yet) and the whole step is redone from the top of this loop.
  // NOTE(review): if eval() kept returning true, this loop would never
  // exit; it relies on the evaluator eventually returning false.
  for (;;) {
    if (m_f->stateVec().size() != m_startedSize) {
      // The size of the state changed, so we have to reallocate the arrays.
      Float time, step;
      stop (time, step);
      start (time, step);
    }
    // NOTE(review): taking [0] assumes the state vector is non-empty; if
    // stateVec() is a std::vector, indexing an empty one is undefined.
    Float * state = &(m_f->stateVec()[0]);
    // Stage 1: k1 = f(t, y).
    bool retry = m_f->eval (m_time, state, m_k1);
    if (!retry) {
      // Stage 2: k2 = f(t + h/2, y + h*k1/2).
      // FIXME SIMD instructions would be great for each of these loops.
      // Prefetches might help too.
      for (int i = 0; i < m_startedSize; i++) {
	m_tmp [i] = state [i] + m_step * m_k1 [i] / 2;
      }
      retry = m_f->eval (m_time + m_step / 2, m_tmp, m_k2);
    }
    if (!retry) {
      // Stage 3: k3 = f(t + h/2, y + h*k2/2).
      for (int i = 0; i < m_startedSize; i++) {
	m_tmp [i] = state [i] + m_step * m_k2 [i] / 2;
      }
      retry = m_f->eval (m_time + m_step / 2, m_tmp, m_k3);
    }
    if (!retry) {
      // Stage 4: k4 = f(t + h, y + h*k3).
      for (int i = 0; i < m_startedSize; i++) {
	m_tmp [i] = state [i] + m_step * m_k3 [i];
      }
      retry = m_f->eval (m_time + m_step, m_tmp, m_k4);
    }
    if (!retry) {
      // Combine: y += h * (k1 + 2*k2 + 2*k3 + k4) / 6, written in place.
      for (int i = 0; i < m_startedSize; i++) {
	state [i] = state [i] +
	  m_step * (m_k1 [i] + 2 * m_k2 [i] + 2 * m_k3 [i] + m_k4 [i]) / 6;
      }
    }
    if (retry) {
      // This isn't right any more, we get here every time objects are deleted.
      // cerr << "One or more objects went out of bounds and have been deleted."
      //<< endl;
    } else {
      break;
    }
  }
  m_time += m_step;
  return m_time;
}

// Stand-alone smoke test: integrate a point around the unit circle in
// `steps` steps and print each position.  After a full revolution the
// point should be back where it started.
//
// NOTE(review): this harness appears out of date relative to the code
// above and likely does not compile with TEST defined:
//   - rk->start is called with five arguments, but RungaKutta::start
//     takes (initialTime, initialStep).
//   - Circle::eval returns void, but observe() uses eval's bool result.
//   - Circle provides no stateVec(), which start()/observe() call.
// TODO: update it to the current Evaluator interface.
#ifdef TEST
class Circle: public Evaluator {
  void eval (Float t, const Float *state, Float *deriv) {
    deriv [1] = state [0];
    deriv [0] = - state [1];
  }
};

#ifndef __iostream_h__
#include "iostream.h"
#define __iostream_h__
#endif

int main () {
  Integrator * rk = NEW (RungaKutta ());
  Circle c = Circle ();
  const Float pi = 3.141592653589793;
  const int steps = 40;
  Float state [2] = {0, 1};
  rk->start (2, 0.0, 2 * pi / steps, state, &c);
  Float time = 0;
  int i = 0;
  for (;;) {
    cout << "Time: " << time << " x: "<<state[0] << " y: "<<state [1] << endl;
    if (i >= steps) break;
    time = rk->observe ();
    i++;
  }
  delete (rk);
  // The point is that when you integrate all the way around the
  // circle, you should be back where you started. 
  return 0;
}
#endif
