#include "rs.h"
#include "gf64k.h"
#include <string.h>

/* Reed-Solomon decoder.
   See Berlekamp, Algebraic Coding Theory, Ch. 10.
   It's much too involved to explain in comments. */

/* Compile-time sanity check: the block size N must equal Q - 1.
   Neither N nor Q is defined in this file; presumably they come from
   rs.h and gf64k.h respectively -- TODO confirm. */
#if (N != Q - 1)
#error header file mismatch -- block size must be |field| - 1
#endif

/* Generator polynomial, monic term omitted.
   After initg (R) has run, g[j] is the coefficient of x^j for
   j = 0..R-1; the remaining entries are zero. */

/* Capacity of g[], i.e. the largest number of checksum symbols R that
   initg can build a generator for.  initg only enforces this limit
   when NDEBUG is not defined. */
#define MAXR 128
static int g[MAXR];

/* Build the monic generator polynomial with R roots in GF(Q) and leave
   its coefficients in g[0..R-1] (the leading 1 is not stored).
   The roots are alg[1] .. alg[R], the first R nonzero powers of alpha.

   R must not exceed MAXR; this is only checked (by aborting) when
   NDEBUG is not defined. */

static void initg (unsigned R)
{
  int root, deg, negroot;

#ifndef NDEBUG
  if (R > MAXR)
    abort ();
#endif

  memset (g, 0, sizeof g);

  /* Fold in one root per pass, starting with alg[R] and ending with
     alg[1].  Each pass multiplies the polynomial built so far by
     (x + negroot); the polynomial occupies the top of g[0..R-1]. */
  root = R;
  while (root > 0) {
    negroot = neg (alg[root]);

    /* Ascending order, so g[deg] is still the old coefficient when it
       is read to update g[deg - 1]. */
    for (deg = root; deg < R; deg++)
      g[deg - 1] = add (g[deg - 1], mul (g[deg], negroot));

    /* Contribution of the implicit leading 1. */
    g[R - 1] = add (g[R - 1], negroot);

    root--;
  }
}

/* Compute the checksum for the K-symbol message M[0..K-1] and place it
   in C[0..R-1].

   K + R must not exceed the block size N (checked by aborting, only
   when NDEBUG is not defined).  R must not exceed MAXR, the capacity
   of the generator table; see initg.  With R == 0 there is no checksum
   to produce and the function returns without touching C.

   The generator polynomial is built lazily and cached in g[].  The
   cache remembers which R it was built for and is rebuilt whenever a
   call asks for a different R, so calls with different checksum
   lengths may be mixed freely.  */

void encode (symbol *m, unsigned K, symbol *c, unsigned R)
{
  /* Number of roots the cached generator in g[] was built for;
     0 means no generator has been built yet. */
  static unsigned g_roots = 0;
  int i, j, t;

#ifndef NDEBUG
  if (K + R > N)
    abort ();
#endif

  /* No checksum symbols requested: nothing to compute, and c[R - 1]
     below would be out of bounds. */
  if (R == 0)
    return;

  /* (Re)build the generator if it is missing or was built for another R.
     Testing g[0] alone cannot detect a change of R. */
  if (g_roots != R) {
    initg (R);
    g_roots = R;
  }

  /* Divide -M by G, put remainder in C. */

  memset (c, 0, R * sizeof *c);

  /* Feed the message symbols in, highest index first.  C acts as a
     shift register: T is the feedback term, every entry moves up one
     place and has T times the matching generator coefficient
     subtracted.  J runs downward so c[j - 1] is still the old value
     when it is read. */
  for (i = K - 1; i >= 0; i--) {
    t = sub (c[R - 1], m[i]);
    for (j = R - 1; j > 0; j--)
      c[j] = sub (c[j - 1], mul (g[j], t));
    c[0] = neg (mul (g[0], t));
  }
}

/* Decode received message M of K symbols with received checksum C of
   R symbols, when no erasure positions are known.
   Corrects up to R/2 symbol errors in (M,C), in place.

   This simply calls decode_erasures with an empty erasure list and
   without asking for the list of corrected positions.

   Returns the number of corrections, or -1 if no correction
   could be performed (because of too many errors). */

int decode (symbol *m, unsigned K, symbol *c, unsigned R)
{
  int ncorrected;

  ncorrected = decode_erasures (m, K, c, R, NULL, 0, NULL);
  return ncorrected;
}

/* Decode received message M, length K, received checksum C (length R),
   with NQ erasures at positions Q[0..NQ-1].  M and C are corrected
   in place.

   Erasure positions are given by numbers in 0..Q-2.  (No repetitions.)
   C[0] is position 0, ..., C[R-1] is position R-1.
   M[0] is position R, ..., M[K-1] is position R+K-1 (which is
   N-1 == Q-2 for a full-length block).
   The values given for erased symbols in (M,C) are unimportant.

   Decoding will succeed if there are at most (R - NQ) / 2 errors.

   If QOUT is non-null, the position of every correction applied is
   stored in QOUT[0..result-1]: first the error positions found by the
   decoder, then the erasure positions in the order given in Q.
   QOUT is not written when the function returns 0 or -1.

   Q is not dereferenced when NQ is 0, so it may be null in that case
   (decode relies on this).

   Returns the number of corrections (errors found plus NQ), 0 if all
   power sums are zero, or -1 if correction could not be performed
   (because of too many errors + erasures).

   In builds without NDEBUG the function aborts if K + R > N or if an
   erasure position lies outside the block. */

int decode_erasures (symbol *m, unsigned K, symbol *c, unsigned R,
		     unsigned *q, unsigned nq, unsigned *qout)
{
  int i, k;
  int a, r, t, w;
  int b, d;
  int ne, nx;
  /* All work arrays are variable-length arrays sized from R.
     s holds the R power sums; it is 2*R long because multiplying in
     each erasure below lengthens it by one entry and nq <= R. */
  int s[2 * R];
  /* sigma/omega: error locator and error evaluator being solved for.
     tau/gamma: the companion polynomials used to update sigma/omega.
     delta: per-step discrepancy. */
  int sigma[R+1], omega[R+1], tau[R+1], gamma[R+1], delta;
  /* Field values (not positions) of the locations to correct. */
  int x[R];

  /* Abort if block too long or erasure positions aren't in block. */

#ifndef NDEBUG
  if (K + R > N)
    abort ();

  for (k = 0; k < nq; k++)
    if (q[k] >= R + K)
      abort ();
#endif

  /* Fail if too many erasures to be correctable (or checkable). */

  if (nq > R)
    return -1;

  /* Compute power sums. */

  /* s[k] is the received word evaluated at alg[k + 1] by Horner's
     rule, taking the symbols from the highest position (m[K-1]) down
     to the lowest (c[0]). */
  for (k = 0; k < R; k++) {
    a = alg[k + 1];
    t = 0;
    for (i = K - 1; i >= 0; i--)
      t = add (m[i], mul (a, t));
    for (i = R - 1; i >= 0; i--)
      t = add (c[i],  mul (a, t));
    s[k] = t; }

  /* Fast return if no errors. */

  /* Note this returns 0 even when erasures were supplied. */
  for (k = 0; k < R; k++)
    if (s[k]) goto dijk;
  return 0;
 dijk:

  /* Multiply S by (1 - Xz) for each erasure position X. */
  /* @@ only need high R-2 coefficients, it says here */

  /* r is the current length of s; it grows by one per erasure and ends
     at R + nq.  k runs downward so s[k] is still the old value when
     it is read. */
  r = R;
  for (i = 0; i < nq; i++) {
    a = alg[q[i]];
    s[r] = 0;
    for (k = r - 1; k >= 0; k--)
      s[k + 1] = sub (s[k + 1], mul (a, s[k]));
    s[0] = sub (s[0], a);
    r++; }

  /* Solve  (1 + S) sigma = omega  for error locator sigma and
     error evaluator omega. */

  /* NOTE(review): this looks like the iterative algorithm from the
     Berlekamp reference cited at the top of the file -- confirm.
     d tracks the current degree of sigma; b is a one-bit flag that is
     toggled together with d and breaks the tie when 2*d == k + 1. */
  memset (sigma, 0, sizeof sigma);
  memset (tau, 0, sizeof tau);
  memset (omega, 0, sizeof omega);
  memset (gamma, 0, sizeof gamma);
  d = b = 0;
  sigma[0] = tau[0] = omega[0] = 1;

  /* R - nq steps; only s[nq..R-1] is read in this loop. */
  for (k = 0; k < R - nq; k++) {
    /* Discrepancy: coefficient k + nq of the product sigma * s. */
    delta = 0;
    for (i = 0; i <= k; i++)
      delta = add (delta, mul (sigma[i], s[k - i + nq]));

    if (delta == 0 || 2 * d > k + 1 || (! b && 2 * d == k + 1)) {
      /* Degree stays the same: correct sigma/omega by delta times
	 tau/gamma shifted up one place, and shift tau/gamma up one
	 place.  i runs downward so tau[i]/gamma[i] are still the old
	 values when they are read. */
      for (i = k; i >= 0; i--) {
	sigma[i+1] = sub (sigma[i+1], mul (delta, tau[i]));
	tau[i+1] = tau[i];
	omega[i+1] = sub (omega[i+1], mul (delta, gamma[i]));
	gamma[i+1] = gamma[i]; }
      tau[0] = gamma[0] = 0; }
    else {
      /* Degree changes: same correction of sigma/omega, but tau/gamma
	 are replaced by the OLD sigma/omega divided by delta.  Because
	 i runs downward, sigma[i]/omega[i] have not been updated yet
	 when they are read here (they are updated as index i+1 of the
	 next iteration). */
      d = k + 1 - d;
      b = 1 - b;
      for (i = k; i >= 0; i--) {
	sigma[i+1] = sub (sigma[i+1], mul (delta, tau[i]));
	tau[i] = div (sigma[i], delta);
	omega[i+1] = sub (omega[i+1], mul (delta, gamma[i]));
	gamma[i] = div (omega[i], delta); }}}

  /* Number of errors NE is the degree of sigma. */

  ne = d;

  /* Find roots of sigma by trying all possibilities. */

  /* The Horner evaluation below treats sigma[0] as the leading
     coefficient, i.e. it evaluates sigma[0] a^ne + ... + sigma[ne].
     Every nonzero field value a is tried until NE roots are found.
     A root a corresponds to position lg[a]. */
  nx = 0;
  for (a = 1; a < Q && nx < ne; a++) {
    t = sigma[0];
    for (i = 1; i <= ne; i++)
      t = add (sigma[i], mul (a, t));
    if (t == 0) {
      if (lg[a] < R + K)
	x[nx++] = a;
      else
	/* correction is not within block, fail */
	return -1; }}

  /* Check if NE roots were found.  If not, not a correctable pattern. */

  if (nx < ne)
    return -1;

  /* Put erasure info from power sums into error evaluator. */

  if (nq) {
    /* Shift omega up by nq places, clearing the low entries. */
    for (k = ne; k >= 0; k--) {
      omega[k + nq] = omega[k];
      omega[k] = 0; }

    /* Add sigma times a polynomial built from the low power sums
       s[0..nq-1] (the entries the solver above did not read).
       After the inner loop i == nq, so the last statement handles the
       top term, using s[nq - 1] - 1 in place of s[nq - 1]. */
    for (k = 0; k <= ne; k++) {
      t = sigma[k];
      omega[k] = add (omega[k], t);
      for (i = 1; i < nq; i++)
	omega[k + i] = add (omega[k + i], mul (t, s[i - 1]));
      omega[k + i] = add (omega[k + i], mul (t, sub (s[i - 1], 1))); }}

  /* Append erasure locations to error location list. */

  /* From here on NE counts errors plus erasures. */
  for (i = 0; i < nq; i++)
    x[ne++] = alg[q[i]];

  /* For each error location, plug into the error evaluator to find
     the correction. */

  for (k = 0; k < ne; k++) {
    a = x[k];
    /* Evaluate omega at a (Horner, omega[0] as leading coefficient),
       then divide by a and by the difference to every other location. */
    w = omega[0];
    for (i = 1; i <= ne; i++)
      w = add (omega[i], mul (a, w));
    w = div (w, a);
    for (i = 0; i < ne; i++)
      if (i != k)
	w = div (w, sub (x[i], x[k]));
    /* a is reused: from here it is the position, not the field value. */
    a = lg[a];
    if (qout)
      qout[k] = a;
    /* Positions below R are checksum symbols, the rest message symbols. */
    if (a < R)
      c[a] = sub (c[a], w);
    else
      a -= R, m[a] = sub (m[a], w); }

  /* Done. */

  return ne;
}

/* Lookup tables filled in by rsinit.
   alg[n] is the n-th value of the sequence generated in rsinit
   (alg[0] == 1); lg[] is the inverse mapping, lg[alg[n]] == n.
   lg[0] is never written by rsinit.  The coder above uses alg[] to
   turn a position or exponent into a field value and lg[] to turn a
   field value back into a position. */
unsigned short lg[65536];
unsigned short alg[65535];

/* Value XORed in whenever the top bit (0x8000) is shifted out of a
   16-bit value in rsinit. */
#define POLY 0xbacd

/* Fill the alg[] and lg[] tables.  Starting from 1, each next value
   is the previous one shifted left by one bit, XORed with POLY when
   bit 15 of the previous value was set, and truncated to 16 bits. */

void rsinit ()
{
  unsigned short elem = 1;
  int e = 0;

  while (e < 65535) {
    alg[e] = elem;
    lg[elem] = e;

    /* The shift is done in int; the assignment back to unsigned short
       discards the bit shifted out of position 15. */
    elem = (elem & 0x8000) ? (elem << 1 ^ POLY) : (elem << 1);

    e++;
  }
}
