//===================================================================
// rankdist.cpp
//
// Version 1.1
//
// Written by:
//   Brent Worden
//   WordenWare
//   email:  Brent@Worden.org
//
// Copyright (c) 1998-1999 WordenWare
//
// Created:  August 28, 1998
// Revised:  April 10, 1999
//===================================================================

#include <algorithm>
#include <cmath>
#include <vector>

#include "normdist.h"
#include "rankdist.h"
#include "vector.hpp"

NUM_BEGIN

// Function needed for Ansari-Bradley
// Seed table for the Ansari-Bradley recursion with a sample of size 1.
// Writes n/2 + 1 entries into f and returns that count. Every entry is 2.0,
// except that for even n the last entry is 1.0.
int start1(int n, double *f)
{
    const int count = n / 2 + 1;

    // Fill from the top down; the order is irrelevant, every slot gets 2.
    int slot = count;
    while (slot > 0) {
        --slot;
        f[slot] = 2.0;
    }

    const bool evenN = (n % 2 == 0);
    if (evenN) {
        f[count - 1] = 1.0;
    }
    return count;
}

// Function needed for Ansari-Bradley
// Seed table for the Ansari-Bradley recursion with a sample of size 2
// against n others. Writes the frequencies into f and returns the number of
// entries written, which is n + 1 for n >= 0; f needs room for that many.
//
// For odd n the "+2" adjustment has to cover 0-based indices ndo .. len-1.
// The previous code started at ndo + 1, so one entry kept its even-case value
// and the table no longer summed to C(n + 2, 2): n = 3 produced {1,4,1,2}
// instead of {1,4,3,2}. The counts of the statistic for sample sizes
// (2, 3), i.e. values 2..5, are 1, 4, 3, 2.
int start2(int n, double *f)
{
    const int nu = n - n % 2;        // n rounded down to an even number (n >= 0)
    int len = nu + 1;                // length of the symmetric even-n table
    const int ndo = (len + 1) / 2;   // number of mirrored pairs to fill

    // Symmetric table 1, 4, 5, 8, 9, ... filled from both ends at once.
    double a = 1.0;
    double b = 3.0;
    for (int i = 0, j = nu; i < ndo; ++i, --j) {
        f[i] = a;
        f[j] = a;
        a += b;
        b = 4.0 - b;
    }

    if (nu != n) {
        // Odd n: raise the upper half (indices ndo .. len-1) by 2 and append
        // a trailing entry of 2.
        for (int i = ndo; i < len; ++i) {
            f[i] += 2.0;
        }
        f[len] = 2.0;
        ++len;
    }

    return len;
}

// Function needed for Ansari-Bradley
// Adds twice the contents of f2 (l2 entries) into f1, aligned so that f2[0]
// lands on f1[*nstart - 1].
//  - f1 slots below l1in (already populated) are accumulated into.
//  - Slots from l1in up to the new length are overwritten.
// Increments *nstart and returns the new length of f1, which is l2 + *nstart - 1
// (using the value of *nstart on entry).
int frqadd(double *f1, int l1in, double *f2, int l2, int *nstart)
{
    const int newLength = l2 + *nstart - 1;
    int src = 0;

    // Overlapping region: accumulate.
    for (int dst = *nstart - 1; dst < l1in; ++dst) {
        f1[dst] += 2.0 * f2[src];
        ++src;
    }

    // Fresh tail: overwrite, continuing from where the overlap stopped in f2.
    for (int dst = l1in; dst < newLength; ++dst) {
        f1[dst] = 2.0 * f2[src];
        ++src;
    }

    *nstart += 1;
    return newLength;
}

// Given the first l1in entries of f1, extends f1 to a symmetric table of
// l1out entries. The contents of f2, lagged by noff positions, are added into
// f1[i] for i >= noff. At the same time, f2 is rebuilt (from both ends) as the
// symmetric table of differences between the new f1 and its previous
// mirrored entries. Returns l1out - noff, the number of entries in the rebuilt
// f2. Values of f2 are read at index i - noff before f2[i] is overwritten in
// the same step.
int imply(double *f1, int l1in, int l1out, double *f2, int noff)
{
    int top1 = l1out;                   // 1-based position in f1, counting down
    int top2 = l1out - noff;            // 1-based position in f2, counting down
    const int rebuiltLength = top2;
    const int top2Floor = (top2 + 1) / 2;
    const int steps = (l1out + 1) / 2;

    for (int i = 0; i < steps; ++i, --top1) {
        const int lagged = i - noff;
        double sum = f1[i];
        if (lagged >= 0) {
            sum += f2[lagged];
            f1[i] = sum;
        }

        if (top2 >= top2Floor) {
            const double diff = (top1 > l1in) ? sum : sum - f1[top1 - 1];
            f2[i] = diff;
            f2[top2 - 1] = diff;
            --top2;
        }

        // Mirror the new value into the upper half of f1.
        f1[top1 - 1] = sum;
    }

    return rebuiltLength;
}

// Function needed for Ansari-Bradley
void abmass(int m, int n, double *a1, int l1)
{
    int i1, i2, lres, mnow, symm, l1out, l2out, i, j, mm, nn;
    int nc, mm1, ln1, nm1, nm2, ln3, ln2, n2b1, n2b2, ndo, astart;
    double ai, *a2, *a3;
    
    a2 = new double[l1+1];
    a3 = new double[l1+1];
    
    mm = ((m < n) ? m : n);
    if (mm < 0) {
        delete [] a2;
        delete [] a3;
        return;
    }
    i1 = (m + 1) / 2;
    i2 = m / 2 + 1;
    astart = i1 * i2;
    nn = ((m > n) ? m : n);
    
    lres = mm * nn / 2 + 1;
    if(l1 < lres){
        delete [] a2;
        delete [] a3;
        return;
    }
    symm = (mm + nn) % 2 == 0;
    
    mm1 = mm - 1;
    if(mm > 2){
        nm1 = nn - 1;
        nm2 = nn - 2;
        mnow = 3;
        nc = 3;
        if(nn % 2 == 1){
            n2b1 = 2;
            n2b2 = 3;
            ln1 = start1(nn, a1);
            ln2 = start2(nm1, a2);
            do {
                l1out = frqadd(a1, ln1, a2, ln2, &n2b1);
                ln1 += nn;
                ln3 = imply(a1, l1out, ln1, a3, nc);
                ++nc;
                if(mnow != mm){
                    ++mnow;
                    l2out = frqadd(a2, ln2, a3, ln3, &n2b2);
                    ln2 += nm1;
                    j = imply(a2, l2out, ln2, a3, nc);
                    ++nc;
                    if(mnow != mm){
                        ++mnow;
                    }
                }
            } while(mnow != mm);
        } else {
            n2b1 = 3;
            n2b2 = 2;
            ln1 = start2(nn, a1);
            ln3 = start2(nm2, a3);
            ln2 = start1(nm1, a2);
            do {
                l2out = frqadd(a2, ln2, a3, ln3, &n2b2);
                ln2 += nm1;
                j = imply(a2, l2out, ln2, a3, nc);
                ++nc;
                if(mnow != mm){
                    ++mnow;
                    l1out = frqadd(a1, ln1, a2, ln2, &n2b1);
                    ln1 += nn;
                    ln3 = imply(a1, l1out, ln1, a3, nc);
                    ++nc;
                    if(mnow != mm){
                        ++mnow;
                    }
                }
            } while(mnow != mm);
        }
        if(symm){
            delete [] a2;
            delete [] a3;
            return;
        }
        
        for(i = (mm+3)/2, j = 1; i <= lres; ++i, ++j) {
            if(i > ln1) {
                a1[i-1] = a2[j-1];
            } else {
                a1[i-1] += a2[j-1];
            }
        }
        
        if(n < m){
            delete [] a2;
            delete [] a3;
            return;
        }
    } else {
        if(mm1 < 0){
            a1[1-1] = 1.0;
            delete [] a2;
            delete [] a3;
            return;
        } else if(mm1 == 0){
            ln1 = start1(nn, a1);
        } else {
            ln1 = start2(nn, a1);
        }
        if(symm || n > m){
            delete [] a2;
            delete [] a3;
            return;
        }
    }
    ndo = lres / 2;
    for(i = 1, j = lres; i <= ndo; ++i, --j) {
        ai = a1[i-1];
        a1[i-1] = a1[j-1];
        a1[j-1] = ai;
    }
    delete [] a2;
    delete [] a3;
}

// Cumulative probability P(AB <= x) of the Ansari-Bradley statistic AB for a
// test sample of size m against n others, computed from the exact frequency
// table produced by abmass.
//
// Entry k of that table is the count for statistic value astart + k, where
// astart is the smallest attainable value for a sample of size m.
// - x below the support (or a NaN x) yields 0.0.
// - x at or beyond the largest attainable value yields 1.0.
// - Negative sample sizes yield 0.0; abmass writes nothing for them.
//
// Previously x < astart produced a negative index, and x past the support
// read and wrote beyond the table. The (int) truncation also mapped
// x in (astart - 1, astart) onto the first entry instead of 0.
NUMERICS_EXPORT double ansbradp(double x, int m, int n)
{
    if (m < 0 || n < 0) {
        return 0.0;
    }

    // Smallest attainable value of the statistic for a sample of size m.
    int astart = (m / 2) * (m / 2 + 1);
    if (m % 2 == 1) {
        astart += (m + 1) / 2;
    }

    const int nrows = m * n / 2 + 1;   // number of table entries
    std::vector<double> freq(nrows + 1, 0.0);
    abmass(m, n, &freq[0], nrows);

    const double offset = std::floor(x - astart);
    if (!(offset >= 0.0)) {
        return 0.0;
    }
    const int last = (offset < nrows - 1) ? (int)offset : nrows - 1;

    // Cumulative count through `last`, and the grand total, in the same
    // summation order as a running cumulative sum.
    double total = 0.0;
    double atOrBelow = 0.0;
    for (int i = 0; i < nrows; ++i) {
        total += freq[i];
        if (i == last) {
            atOrBelow = total;
        }
    }

    return atOrBelow / total;
}

// Function needed for Wilcoxon
// Recursive counting helper for wilcoxonp.
//
// For the arguments wilcoxonp passes (x at most half the maximum rank sum),
// the result matches the number of k-element subsets of {1, ..., n} whose
// sum is at most x (e.g. wilcox1(6, 5, 2) == 6). It is not a general subset
// counter: for larger x the recursion can reach the k == 1 base case with a
// negative x, which is returned as-is.
int wilcox1(int x, int n, int k)
{
    if (k == 1) {
        return (x < n) ? x : n;
    }

    const int h = k - 1;
    int pool = n - 1;

    // choose = C(pool, h), built incrementally so every division is exact.
    int choose = 1;
    for (int j = 1; j <= h; ++j) {
        choose = choose * (pool - j + 1) / j;
    }

    int maxSum = h * (pool - h) + h * (h + 1) / 2;
    int rest = x - h - 1;
    int total = 0;

    // While every h-subset of the current pool fits, count them in bulk.
    for (; rest >= maxSum; rest -= h + 1) {
        total += choose;
        choose = (choose * (pool - h)) / pool;
        --pool;
        maxSum -= h;
    }

    // Otherwise recurse on smaller subsets.
    do {
        total += wilcox1(rest, pool, h);
        --pool;
        rest -= h + 1;
    } while (2 * rest >= h * (h + 1));

    return total;
}

// Cumulative probability P(T <= x) of the Wilcoxon signed-rank statistic T
// for n observations.
//
// For n <= 25 the result is exact: subsets of {1..n} with sum <= x are
// counted and divided by 2^n. The upper half of the range is handled through
// the symmetry T <-> n(n+1)/2 - T. For n > 25, normalp is used with a 0.5
// continuity correction.
//
// Negative n is rejected with 0.0; previously it could loop forever (e.g.
// n = -2, x = 0).
NUMERICS_EXPORT double wilcoxonp(int x, int n)
{
    if (n < 0) {
        return 0.0;
    }

    if (n > 25) {
        const double mean = n * (n + 1.0) / 4.0;
        // Variance of T, passed as normalp's third argument as before.
        // NOTE(review): assumes normalp takes a variance, not a standard
        // deviation - confirm against normdist.h.
        const double cv = n * (n + 1.0) * (2.0 * n + 1.0) / 24.0;
        return normalp(x + .5, mean, cv);
    }

    if (x < 0) {
        return 0.0;
    }
    const int maxSum = n * (n + 1) / 2;   // largest attainable rank sum
    if (x >= maxSum) {
        return 1.0;
    }

    // Count the smaller tail; P(T <= x) = 1 - P(T <= maxSum - x - 1).
    const bool lowerTail = !(2 * x > maxSum);
    const int target = lowerTail ? x : maxSum - x - 1;

    int binom = 1;            // C(n, size) once updated
    int size = 1;             // current subset size
    int minSum = 1;           // smallest sum of a subset of this size
    int maxOfSize = n;        // largest sum of a subset of this size
    int count = 1;            // the empty subset

    // Sizes whose every subset has sum <= target: add C(n, size) directly.
    while (target >= maxOfSize) {
        binom = (binom * (n - size + 1)) / size;
        count += binom;
        ++size;
        minSum += size;
        maxOfSize += n + 1 - size;
    }

    // Remaining sizes that can still fit: count them recursively.
    do {
        count += wilcox1(target, n, size);
        ++size;
        minSum += size;
    } while (target >= minSum);

    const double tail = std::ldexp((double)count, -n);
    return lowerTail ? tail : 1.0 - tail;
}

// Critical values of the Wilcoxon signed-rank statistic for probability p
// and n observations.
// - On invalid input (p outside [0, 1], p NaN, or n < 1) both outputs are -1.
// - For n <= 25 the exact CDF is searched: *w1 is the smallest w with
//   wilcoxonp(w, n) > p, capped at the maximum rank sum. *w0 is w1 - 1,
//   except that w0 == w1 when w1 <= 1.
// - For n > 25, floor/ceil of the clamped normalv quantile are returned.
//
// Fixes:
// - Arguments are validated before normalv is called.
// - The exact search is bounded by the maximum rank sum; with p == 1 it
//   previously never terminated.
// - The normalv result is clamped before any (int) conversion, so a
//   non-finite quantile is no longer cast to int.
NUMERICS_EXPORT void wilcoxonv(double p, int n, int *w0, int *w1)
{
    if (!(p >= 0.0 && p <= 1.0) || n < 1) {
        *w0 = *w1 = -1;
        return;
    }

    const double m = n * (n + 1.0) / 4.0;
    // Variance of the statistic, passed to normalv as before.
    // NOTE(review): assumes normalv takes a variance - confirm in normdist.h.
    const double v = n * (n + 1.0) * (2.0 * n + 1.0) / 24.0;
    double wd = normalv(p, m, v);

    if (n > 25) {
        if (!(wd >= 1.0)) {
            wd = 1.0;
        } else if (wd > 2.0 * m) {
            wd = 2.0 * m - .5;
        }
        *w0 = (int)std::floor(wd);
        *w1 = (int)std::ceil(wd);
        return;
    }

    // Exact search. The starting point only affects speed, so clamp it into
    // the support before converting.
    const int maxSum = n * (n + 1) / 2;
    if (!(wd >= 0.0)) {
        wd = 0.0;
    } else if (wd > maxSum) {
        wd = maxSum;
    }
    int w = (int)wd;
    while (wilcoxonp(w, n) > p) {
        --w;
    }
    while (w < maxSum && wilcoxonp(w, n) <= p) {
        ++w;
    }

    *w0 = (w <= 1) ? w : w - 1;
    *w1 = w;
}

// Cumulative probability P(W <= w) of the Wilcoxon rank-sum statistic W for a
// sample of size m against n others. The computation uses the exact
// distribution of U = W - m(m+1)/2 over 0 .. m*n, built by the recursive
// frequency generator below.
// - w below the support, or min(m, n) < 1, gives 0.0.
// - w at or above the largest attainable value gives 1.0.
//
// Fixes:
// - When min(m, n) == 1 the function previously always returned 0.0, because
//   the cumulative step was inside the "minmn != 1" branch. In that case U is
//   uniform on 0 .. max(m, n).
// - u > m*n previously indexed past the frequency table.
NUMERICS_EXPORT double wilmannwhitp(int w, int m, int n)
{
    const int u = w - m * (m + 1) / 2;
    const int minmn = (m < n) ? m : n;
    const int maxmn = (m > n) ? m : n;

    if (u < 0 || minmn < 1) {
        return 0.0;
    }
    const int mn = m * n;          // largest attainable U
    if (u >= mn) {
        return 1.0;
    }

    // Frequencies of U = 0 .. mn; start from the min(m, n) == 1 case, which
    // has U = 0 .. maxmn each with frequency 1.
    const int mn1 = mn + 1;
    std::vector<double> freq(mn1, 0.0);
    for (int i = 0; i <= maxmn; ++i) {
        freq[i] = 1.0;
    }

    if (minmn > 1) {
        // Generate successively higher-order distributions, filling each one
        // from the outside inwards.
        std::vector<double> work((mn1 + 1) / 2 + minmn, 0.0);
        int in = maxmn;
        for (int i = 2; i <= minmn; ++i) {
            work[i - 1] = 0.0;   // this slot was written during the previous pass
            in += maxmn;
            int n1 = in + 2;
            const int l = 1 + in / 2;
            int k = i;
            for (int j = 1; j <= l; ++j) {
                ++k;
                --n1;
                const double sum = freq[j - 1] + work[j - 1];
                freq[j - 1] = sum;
                work[k - 1] = sum - freq[n1 - 1];
                freq[n1 - 1] = sum;
            }
        }
    }

    double total = 0.0;
    double atOrBelow = 0.0;
    for (int i = 0; i < mn1; ++i) {
        total += freq[i];
        if (i == u) {
            atOrBelow = total;
        }
    }
    return atOrBelow / total;
}

// Critical values of the Wilcoxon rank-sum statistic for probability p, with
// a sample of size m against n others.
// - On invalid input (p outside [0, 1], p NaN, m < 1, or n < 1) both outputs
//   are -1.
// - Otherwise *w1 is the smallest w with wilmannwhitp(w, m, n) > p, capped at
//   the maximum rank sum, and *w0 is w1 - 1, but never below the minimum rank
//   sum m(m+1)/2.
//
// Fixes:
// - m < 1 is now rejected; previously wilmannwhitp returned 0 for every w
//   and the upward search never ended.
// - The upward search is bounded, so p == 1 also terminates.
// - The normalv starting point is clamped into the support before its
//   (int) conversion.
NUMERICS_EXPORT void wilmannwhitv(double p, int m, int n, int *w0, int *w1)
{
    if (!(p >= 0.0 && p <= 1.0) || m < 1 || n < 1) {
        *w0 = *w1 = -1;
        return;
    }

    const int wmin = m * (m + 1) / 2;   // smallest attainable rank sum
    const int wmax = wmin + m * n;      // largest attainable rank sum

    const double mn = m * (m + n + 1.0) / 2.0;
    // Variance of the statistic, passed to normalv as before.
    // NOTE(review): assumes normalv takes a variance - confirm in normdist.h.
    const double v = (double)m * n * (m + n + 1.0) / 12.0;
    double wd = normalv(p, mn, v);

    // The starting point only affects how far the search walks.
    if (!(wd >= wmin)) {
        wd = wmin;
    } else if (wd > wmax) {
        wd = wmax;
    }
    int w = (int)wd;

    while (wilmannwhitp(w, m, n) > p) {
        --w;
    }
    while (w < wmax && wilmannwhitp(w, m, n) <= p) {
        ++w;
    }

    *w0 = (w <= wmin) ? wmin : w - 1;
    *w1 = w;
}

NUM_END

//===================================================================
// Revision History
//
// Version 1.0 - 08/28/1998 - New.
// Version 1.1 - 04/10/1999 - Added Numerics namespace.
//===================================================================
