/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import java.util.HashMap;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.Callable;
import org.apache.commons.lang3.tuple.MutableTriple;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.StringArray;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;

/**
 * Column encoder that discretizes one numeric column into bins and encodes each value by its
 * 1-based bin id.
 * <p>
 * Supported methods (see {@link BinMethod}):
 * <ul>
 * <li>{@code EQUI_WIDTH}: {@code _numBin} intervals of equal width spanning the observed [min, max].</li>
 * <li>{@code EQUI_HEIGHT}: bin maxima taken at the quantiles of the sorted column.</li>
 * <li>{@code EQUI_HEIGHT_APPROX}: like {@code EQUI_HEIGHT}, but the quantiles are computed from a
 * random sample of the rows (single-threaded build only).</li>
 * </ul>
 * Missing values (NaN) are ignored when boundaries are computed and are encoded as NaN.
 */
public class ColumnEncoderBin extends ColumnEncoder {
    public static final String MIN_PREFIX = "min";
    public static final String MAX_PREFIX = "max";
    public static final String NBINS_PREFIX = "nbins";
    private static final long serialVersionUID = 1917445005206076078L;
    /** Fraction of rows sampled by {@link BinMethod#EQUI_HEIGHT_APPROX}. */
    public static final double SAMPLE_FRACTION = 0.1;
    /** Minimum sample size for {@link BinMethod#EQUI_HEIGHT_APPROX} (capped at the row count). */
    public static final int MINIMUM_SAMPLE_SIZE = 1000;

    protected int _numBin = -1;
    private BinMethod _binMethod = BinMethod.EQUI_WIDTH;
    private double[] _binMins = null;
    private double[] _binMaxs = null;
    // column min/max collected by buildPartial; reset to -1 by prepareBuildPartial
    private double _colMins = -1.0;
    private double _colMaxs = -1.0;

    /** No-arg constructor; column id and bin count are set later (e.g. via readExternal). */
    public ColumnEncoderBin() {
        super(-1);
    }

    /**
     * @param colID     1-based column id
     * @param numBin    number of bins to create
     * @param binMethod binning method used by build
     */
    public ColumnEncoderBin(int colID, int numBin, BinMethod binMethod) {
        super(colID);
        _numBin = numBin;
        _binMethod = binMethod;
    }

    /**
     * Creates an encoder with precomputed bin boundaries.
     *
     * @param colID   1-based column id
     * @param numBin  number of bins
     * @param binMins lower boundary per bin
     * @param binMaxs upper boundary per bin
     */
    public ColumnEncoderBin(int colID, int numBin, double[] binMins, double[] binMaxs) {
        super(colID);
        _numBin = numBin;
        _binMins = binMins;
        _binMaxs = binMaxs;
    }

    public int getNumBin() {
        return _numBin;
    }

    public double getColMins() {
        return _colMins;
    }

    public double getColMaxs() {
        return _colMaxs;
    }

    public double[] getBinMins() {
        return _binMins;
    }

    public double[] getBinMaxs() {
        return _binMaxs;
    }

    public BinMethod getBinMethod() {
        return _binMethod;
    }

    /**
     * Sets the binning method from its textual name (case-insensitive match against
     * {@link BinMethod#toString()}).
     *
     * @param method method name
     * @throws RuntimeException if the name matches no supported method
     */
    public void setBinMethod(String method) {
        if (method.equalsIgnoreCase(BinMethod.EQUI_WIDTH.toString())) {
            _binMethod = BinMethod.EQUI_WIDTH;
        } else if (method.equalsIgnoreCase(BinMethod.EQUI_HEIGHT.toString())) {
            _binMethod = BinMethod.EQUI_HEIGHT;
        } else if (method.equalsIgnoreCase(BinMethod.EQUI_HEIGHT_APPROX.toString())) {
            _binMethod = BinMethod.EQUI_HEIGHT_APPROX;
        } else {
            throw new RuntimeException(method + " is invalid");
        }
    }

    /**
     * Computes the bin boundaries from the whole input column according to the configured method.
     */
    @Override
    public void build(CacheBlock<?> in) {
        final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (!isApplicable()) {
            return;
        }
        if (_binMethod == BinMethod.EQUI_WIDTH) {
            double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1);
            computeBins(pairMinMax[0], pairMinMax[1]);
        } else if (_binMethod == BinMethod.EQUI_HEIGHT) {
            double[] sortedCol = prepareDataForEqualHeightBins(in, _colID, 0, -1);
            computeEqualHeightBins(sortedCol, false);
        } else if (_binMethod == BinMethod.EQUI_HEIGHT_APPROX) {
            double[] vals = sampleDoubleColumn(in, _colID, SAMPLE_FRACTION, MINIMUM_SAMPLE_SIZE);
            Arrays.sort(vals);
            computeEqualHeightBins(vals, false);
        }
        if (DMLScript.STATISTICS) {
            TransformStatistics.incBinningBuildTime(System.nanoTime() - t0);
        }
    }

    /**
     * Computes the bin boundaries, using caller-provided bin maxima for the equi-height methods.
     *
     * @param in             input block (only read for EQUI_WIDTH)
     * @param equiHeightMaxs for equi-height methods: the first entry is used as the lowest bin
     *                       minimum and entries {@code 1.._numBin} as bin maxima
     */
    @Override
    public void build(CacheBlock<?> in, double[] equiHeightMaxs) {
        final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (!isApplicable()) {
            return;
        }
        if (_binMethod == BinMethod.EQUI_WIDTH) {
            double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1);
            computeBins(pairMinMax[0], pairMinMax[1]);
        } else if (_binMethod == BinMethod.EQUI_HEIGHT || _binMethod == BinMethod.EQUI_HEIGHT_APPROX) {
            computeEqualHeightBins(equiHeightMaxs, true);
        }
        if (DMLScript.STATISTICS) {
            TransformStatistics.incBinningBuildTime(System.nanoTime() - t0);
        }
    }

    /** @return true if both boundary arrays exist and are non-empty. */
    private boolean hasBinBoundaries() {
        return _binMins != null && _binMaxs != null && _binMins.length > 0 && _binMaxs.length > 0;
    }

    /**
     * Encodes a single cell. Without bin boundaries every value is assigned bin 1 (with a warning).
     */
    @Override
    protected double getCode(CacheBlock<?> in, int row) {
        if (!hasBinBoundaries()) {
            LOG.warn("ColumnEncoderBin: applyValue without bucket boundaries, assign 1");
            return 1.0;
        }
        return getCodeIndex(in.getDoubleNaN(row, _colID - 1));
    }

    /**
     * Encodes rows {@code [startInd, endInd)} of this column.
     *
     * @param tmp optional output buffer, reused only if its length equals the row range
     * @return bin codes, one per row of the range
     */
    @Override
    protected final double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) {
        final int len = endInd - startInd;
        final double[] codes = (tmp != null && tmp.length == len) ? tmp : new double[len];
        if (!hasBinBoundaries()) {
            LOG.warn("ColumnEncoderBin: applyValue without bucket boundaries, assign 1");
            Arrays.fill(codes, 0, len, 1.0);
            return codes;
        }
        if (in instanceof FrameBlock) {
            getCodeColFrame((FrameBlock) in, startInd, endInd, codes);
        } else {
            for (int i = startInd; i < endInd; i++) {
                codes[i - startInd] = getCodeIndex(in.getDoubleNaN(i, _colID - 1));
            }
        }
        return codes;
    }

    /** Frame-specific variant of {@link #getCodeCol} that reads the column array directly. */
    protected final void getCodeColFrame(FrameBlock in, int startInd, int endInd, double[] codes) {
        final Array<?> col = in.getColumn(_colID - 1);
        final double mi = _binMins[0];
        final double mx = _binMaxs[_binMaxs.length - 1];
        // fast path: values can be read without NaN conversion of strings/nulls
        if (!(col instanceof StringArray) && !col.containsNull()) {
            for (int i = startInd; i < endInd; i++) {
                codes[i - startInd] = getCodeIndex(col.getAsDouble(i), mi, mx);
            }
        } else {
            for (int i = startInd; i < endInd; i++) {
                codes[i - startInd] = getCodeIndex(col.getAsNaNDouble(i), mi, mx);
            }
        }
    }

    /** @return the 1-based bin id for {@code inVal}, or NaN if {@code inVal} is NaN. */
    protected final double getCodeIndex(double inVal) {
        return getCodeIndex(inVal, _binMins[0], _binMaxs[_binMaxs.length - 1]);
    }

    /**
     * @param min lowest bin boundary (used by EQUI_WIDTH only)
     * @param max highest bin boundary (used by EQUI_WIDTH only)
     * @return the 1-based bin id for {@code inVal}, or NaN if {@code inVal} is NaN
     */
    protected final double getCodeIndex(double inVal, double min, double max) {
        if (Double.isNaN(inVal)) {
            return Double.NaN;
        }
        if (_binMethod == BinMethod.EQUI_WIDTH) {
            return getEqWidth(inVal, min, max);
        }
        return getCodeIndexEQHeight(inVal);
    }

    /** Equi-width bin id, clamped to {@code [1, _numBin]}; a degenerate range maps to bin 1. */
    private final double getEqWidth(double inVal, double min, double max) {
        if (max == min) {
            return 1.0;
        }
        if (_numBin <= 0) {
            throw new RuntimeException("Invalid num bins");
        }
        int code = (int) Math.ceil((inVal - min) / (max - min) * _numBin);
        if (code > _numBin) {
            return _numBin;
        }
        return code < 1 ? 1.0 : code;
    }

    /** Equi-height bin id: linear scan for few bins, binary search otherwise. */
    private final double getCodeIndexEQHeight(double inVal) {
        return _binMaxs.length <= 10 ? getCodeIndexEQHeightSmall(inVal) : getCodeIndexEQHeightNormal(inVal);
    }

    private final double getCodeIndexEQHeightSmall(double inVal) {
        for (int i = 0; i < _binMaxs.length - 1; i++) {
            if (inVal <= _binMaxs[i]) {
                return i + 1;
            }
        }
        return _binMaxs.length;
    }

    private final double getCodeIndexEQHeightNormal(double inVal) {
        final int ix = Arrays.binarySearch(_binMaxs, inVal);
        if (ix < 0) {
            // not an exact boundary: the value belongs to the bin at the insertion point
            final int insertionPoint = -(ix + 1);
            return Math.min(insertionPoint + 1, _binMaxs.length);
        }
        return Math.min(ix + 1, _binMaxs.length);
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.BIN;
    }

    /**
     * Scans rows {@code [startRow, end)} of column {@code colID} (1-based) and returns
     * {@code {min, max}} over the non-NaN values. With no non-NaN values the result is
     * {@code {+Inf, -Inf}}. The row range end comes from {@code UtilFunctions.getEndIndex}
     * (presumably the whole column when {@code blockSize} is -1).
     */
    private static double[] getMinMaxOfCol(CacheBlock<?> in, int colID, int startRow, int blockSize) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        final int end = UtilFunctions.getEndIndex(in.getNumRows(), startRow, blockSize);
        for (int i = startRow; i < end; i++) {
            double inVal = in.getDoubleNaN(i, colID - 1);
            if (Double.isNaN(inVal)) {
                continue;
            }
            min = Math.min(min, inVal);
            max = Math.max(max, inVal);
        }
        return new double[] {min, max};
    }

    /** Extracts the non-NaN values of a row range and returns them sorted ascending. */
    private static double[] prepareDataForEqualHeightBins(CacheBlock<?> in, int colID, int startRow, int blockSize) {
        double[] vals = extractDoubleColumn(in, colID, startRow, blockSize);
        Arrays.sort(vals);
        return vals;
    }

    /**
     * Extracts the non-NaN values of rows {@code [startRow, end)} of column {@code colID}
     * (1-based). Missing values are dropped rather than left as 0.0, so they cannot distort the
     * equi-height quantiles; the result may therefore be shorter than the row range.
     */
    private static double[] extractDoubleColumn(CacheBlock<?> in, int colID, int startRow, int blockSize) {
        final int endRow = UtilFunctions.getEndIndex(in.getNumRows(), startRow, blockSize);
        final double[] vals = new double[endRow - startRow];
        final int cid = colID - 1;
        int k = 0;
        if (in instanceof FrameBlock) {
            // frame optimization: read the column array directly
            final Array<?> a = ((FrameBlock) in).getColumn(cid);
            for (int i = startRow; i < endRow; i++) {
                double inVal = a.getAsNaNDouble(i);
                if (!Double.isNaN(inVal)) {
                    vals[k++] = inVal;
                }
            }
        } else {
            for (int i = startRow; i < endRow; i++) {
                double inVal = in.getDoubleNaN(i, cid);
                if (!Double.isNaN(inVal)) {
                    vals[k++] = inVal;
                }
            }
        }
        return k == vals.length ? vals : Arrays.copyOf(vals, k);
    }

    /**
     * Draws a random sample (with replacement) of column {@code colID} (1-based).
     * The number of draws is {@code min(nRow, max(minimumSampleSize, ceil(nRow * sampleFraction)))};
     * NaN draws are discarded, so the result may be shorter. Uses {@code DMLScript.SEED} unless it is -1.
     */
    private static double[] sampleDoubleColumn(CacheBlock<?> in, int colID, double sampleFraction, int minimumSampleSize) {
        final int nRow = in.getNumRows();
        final int draws = (int) Math.min((double) nRow,
            Math.max((double) minimumSampleSize, Math.ceil((double) nRow * sampleFraction)));
        final int cid = colID - 1;
        // frame columns are read directly; any other cache block goes through getDoubleNaN
        final Array<?> col = (in instanceof FrameBlock) ? ((FrameBlock) in).getColumn(cid) : null;
        final int seed = DMLScript.SEED;
        final Random rand = seed == -1 ? new Random() : new Random(seed);
        final double[] vals = new double[draws];
        int k = 0;
        for (int i = 0; i < draws; i++) {
            int row = rand.nextInt(nRow);
            double v = (col != null) ? col.getAsNaNDouble(row) : in.getDoubleNaN(row, cid);
            if (!Double.isNaN(v)) {
                vals[k++] = v;
            }
        }
        return k == draws ? vals : Arrays.copyOf(vals, k);
    }

    @Override
    public Callable<Object> getBuildTask(CacheBlock<?> in) {
        return new ColumnBinBuildTask(this, in);
    }

    @Override
    public Callable<Object> getPartialBuildTask(CacheBlock<?> in, int startRow, int blockSize, HashMap<Integer, Object> ret) {
        return new BinPartialBuildTask(in, _colID, startRow, blockSize, _binMethod, ret);
    }

    @Override
    public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
        return new BinMergePartialBuildTask(this, ret);
    }

    /**
     * (Re)allocates the boundary arrays when they are missing or their length disagrees with
     * {@code _numBin} (e.g. after {@link #initMetaData} loaded a different number of bins).
     */
    private void allocateBinArrays() {
        boolean mismatch = _binMins == null || _binMaxs == null
            || (_numBin >= 0 && (_binMins.length != _numBin || _binMaxs.length != _numBin));
        if (mismatch) {
            _binMins = new double[_numBin];
            _binMaxs = new double[_numBin];
        }
    }

    /** Fills the boundaries with {@code _numBin} equal-width intervals covering {@code [min, max]}. */
    public void computeBins(double min, double max) {
        allocateBinArrays();
        for (int i = 0; i < _numBin; i++) {
            _binMins[i] = min + (double) i * (max - min) / (double) _numBin;
            _binMaxs[i] = min + (double) (i + 1) * (max - min) / (double) _numBin;
        }
    }

    /**
     * Computes equi-height boundaries.
     *
     * @param sortedCol          ascending values; if {@code doNotTakeQuantiles}, the first entry
     *                           is the lowest bin minimum followed by {@code _numBin} bin maxima
     * @param doNotTakeQuantiles true to copy precomputed maxima instead of taking quantiles
     */
    private void computeEqualHeightBins(double[] sortedCol, boolean doNotTakeQuantiles) {
        allocateBinArrays();
        if (!doNotTakeQuantiles) {
            final int n = sortedCol.length;
            if (n == 0) {
                // no non-missing values observed: boundaries are undefined
                Arrays.fill(_binMins, Double.NaN);
                Arrays.fill(_binMaxs, Double.NaN);
                return;
            }
            for (int i = 0; i < _numBin; i++) {
                double pos = (double) n * ((double) i + 1.0) / (double) _numBin;
                // integer position -> 1-based rank; fractional -> next element
                _binMaxs[i] = pos % 1.0 == 0.0 ? sortedCol[(int) pos - 1] : sortedCol[(int) Math.floor(pos)];
            }
            _binMaxs[_numBin - 1] = sortedCol[n - 1];
        } else {
            System.arraycopy(sortedCol, 1, _binMaxs, 0, _numBin);
        }
        // each bin starts where the previous one ends
        _binMins[0] = sortedCol[0];
        System.arraycopy(_binMaxs, 0, _binMins, 1, _numBin - 1);
    }

    @Override
    public void prepareBuildPartial() {
        _colMins = -1.0;
        _colMaxs = -1.0;
    }

    /** Records the min/max of the column of {@code in} into the column min/max fields. */
    @Override
    public void buildPartial(FrameBlock in) {
        if (!isApplicable()) {
            return;
        }
        double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1);
        _colMins = pairMinMax[0];
        _colMaxs = pairMinMax[1];
    }

    @Override
    protected ColumnEncoder.ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock<?> in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new BinSparseApplyTask(this, in, out, outputCol);
    }

    /**
     * Merges another bin encoder of the same column: the combined [min, max] range of both is
     * split into {@code _numBin} equal-width bins. Other encoder types are delegated to super.
     */
    @Override
    public void mergeAt(ColumnEncoder other) {
        if (other instanceof ColumnEncoderBin) {
            ColumnEncoderBin otherBin = (ColumnEncoderBin) other;
            assert (other._colID == _colID);
            double min = Math.min(_binMins[0], otherBin._binMins[0]);
            double max = Math.max(_binMaxs[_binMaxs.length - 1], otherBin._binMaxs[otherBin._binMaxs.length - 1]);
            _binMins = new double[_numBin];
            _binMaxs = new double[_numBin];
            computeBins(min, max);
            return;
        }
        super.mergeAt(other);
    }

    @Override
    public void allocateMetaData(FrameBlock meta) {
        meta.ensureAllocatedColumns(_binMaxs.length);
    }

    /** Writes one "min\u00b7max" string per bin into this column of {@code meta}. */
    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        meta.ensureAllocatedColumns(_binMaxs.length);
        meta.getColumnMetadata(_colID - 1).setNumDistinct(_numBin);
        for (int i = 0; i < _binMaxs.length; i++) {
            String sb = _binMins[i] + "\u00b7" + _binMaxs[i];
            meta.set(i, _colID - 1, sb);
        }
        return meta;
    }

    /**
     * Restores bin boundaries from metadata written by {@link #getMetaData}, unless boundaries
     * already exist or the column metadata is default.
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        if (meta == null || _binMaxs != null || meta.getColumnMetadata()[_colID - 1].isDefault()) {
            return;
        }
        final int nbins = (int) meta.getColumnMetadata()[_colID - 1].getNumDistinct();
        _binMins = new double[nbins];
        _binMaxs = new double[nbins];
        for (int i = 0; i < nbins; i++) {
            String[] tmp = meta.get(i, _colID - 1).toString().split("\u00b7");
            _binMins[i] = Double.parseDouble(tmp[0]);
            _binMaxs[i] = Double.parseDouble(tmp[1]);
        }
    }

    /** Serializes numBin, the method name, and interleaved (max, min) pairs if boundaries exist. */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);
        out.writeInt(_numBin);
        out.writeUTF(_binMethod.toString());
        out.writeBoolean(_binMaxs != null);
        if (_binMaxs != null) {
            for (int j = 0; j < _binMaxs.length; j++) {
                out.writeDouble(_binMaxs[j]);
                out.writeDouble(_binMins[j]);
            }
        }
    }

    /** Reads the format produced by {@link #writeExternal}. */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        super.readExternal(in);
        _numBin = in.readInt();
        setBinMethod(in.readUTF());
        if (!in.readBoolean()) {
            _binMaxs = null;
            _binMins = null;
            return;
        }
        _binMaxs = new double[_numBin];
        _binMins = new double[_numBin];
        for (int j = 0; j < _numBin; j++) {
            _binMaxs[j] = in.readDouble();
            _binMins[j] = in.readDouble();
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getClass().getSimpleName());
        sb.append(": ");
        sb.append(_colID);
        sb.append(" --- Method: " + _binMethod + " num Bin: " + _numBin);
        sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
        sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
        return sb.toString();
    }

    /** Runs a full {@link ColumnEncoderBin#build(CacheBlock)} as a task. */
    private static class ColumnBinBuildTask implements Callable<Object> {
        private final ColumnEncoderBin _encoder;
        private final CacheBlock<?> _input;

        protected ColumnBinBuildTask(ColumnEncoderBin encoder, CacheBlock<?> input) {
            _encoder = encoder;
            _input = input;
        }

        @Override
        public Void call() throws Exception {
            _encoder.build(_input);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
        }
    }

    /** Cursor into a sorted array, ordered by the value at the cursor (for k-way merging). */
    private static class ArrayContainer implements Comparable<ArrayContainer> {
        double[] arr;
        int index;

        public ArrayContainer(double[] arr, int index) {
            this.arr = arr;
            this.index = index;
        }

        @Override
        public int compareTo(ArrayContainer o) {
            // Double.compare yields 0 for ties, keeping the Comparable contract symmetric
            return Double.compare(arr[index], o.arr[o.index]);
        }
    }

    /** Merges the per-block partial results into final bin boundaries. */
    private static class BinMergePartialBuildTask implements Callable<Object> {
        private final HashMap<Integer, ?> _partialMaps;
        private final ColumnEncoderBin _encoder;

        private BinMergePartialBuildTask(ColumnEncoderBin encoderBin, HashMap<Integer, ?> partialMaps) {
            _partialMaps = partialMaps;
            _encoder = encoderBin;
        }

        /** Merges k ascending arrays into one ascending array; null/empty inputs are skipped. */
        private double[] mergeKSortedArrays(double[][] arrs) {
            PriorityQueue<ArrayContainer> queue = new PriorityQueue<ArrayContainer>();
            int total = 0;
            for (double[] arr : arrs) {
                if (arr == null || arr.length == 0) {
                    continue;
                }
                queue.add(new ArrayContainer(arr, 0));
                total += arr.length;
            }
            double[] result = new double[total];
            int m = 0;
            while (!queue.isEmpty()) {
                ArrayContainer ac = queue.poll();
                result[m++] = ac.arr[ac.index];
                if (ac.index < ac.arr.length - 1) {
                    queue.add(new ArrayContainer(ac.arr, ac.index + 1));
                }
            }
            return result;
        }

        @Override
        public Object call() throws Exception {
            final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            final BinMethod method = _encoder.getBinMethod();
            if (method == BinMethod.EQUI_WIDTH) {
                // partials are {min, max} pairs
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (Object minMax : _partialMaps.values()) {
                    min = Math.min(min, ((double[]) minMax)[0]);
                    max = Math.max(max, ((double[]) minMax)[1]);
                }
                _encoder.computeBins(min, max);
            } else if (method == BinMethod.EQUI_HEIGHT || method == BinMethod.EQUI_HEIGHT_APPROX) {
                // partials are sorted value arrays (produced for both equi-height methods)
                double[][] allParts = new double[_partialMaps.size()][];
                int i = 0;
                for (Object arr : _partialMaps.values()) {
                    allParts[i++] = (double[]) arr;
                }
                double[] sortedRes = mergeKSortedArrays(allParts);
                _encoder.computeEqualHeightBins(sortedRes, false);
            }
            if (DMLScript.STATISTICS) {
                TransformStatistics.incBinningBuildTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + _encoder._colID + ">";
        }
    }

    /**
     * Computes the partial result for one row block and stores it in the shared map under the
     * block's start row: {min, max} for EQUI_WIDTH, sorted non-NaN values for the equi-height
     * methods. Map writes are synchronized on the map.
     */
    private static class BinPartialBuildTask implements Callable<Object> {
        private final CacheBlock<?> _input;
        private final int _blockSize;
        private final int _startRow;
        private final int _colID;
        private final BinMethod _method;
        private final HashMap<Integer, Object> _partialData;

        protected BinPartialBuildTask(CacheBlock<?> input, int colID, int startRow, int blocksize, BinMethod method, HashMap<Integer, Object> partialData) {
            _input = input;
            _blockSize = blocksize;
            _colID = colID;
            _startRow = startRow;
            _method = method;
            _partialData = partialData;
        }

        @Override
        public double[] call() throws Exception {
            final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            if (_method == BinMethod.EQUI_WIDTH) {
                double[] minMax = getMinMaxOfCol(_input, _colID, _startRow, _blockSize);
                synchronized (_partialData) {
                    _partialData.put(_startRow, minMax);
                }
            }
            if (_method == BinMethod.EQUI_HEIGHT || _method == BinMethod.EQUI_HEIGHT_APPROX) {
                double[] sortedVals = prepareDataForEqualHeightBins(_input, _colID, _startRow, _blockSize);
                synchronized (_partialData) {
                    _partialData.put(_startRow, sortedVals);
                }
            }
            if (DMLScript.STATISTICS) {
                TransformStatistics.incBinningBuildTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<Start row: " + _startRow + "; Block size: " + _blockSize + ">";
        }
    }

    /** Applies the encoder into a sparse output block; does nothing if the output has no sparse block. */
    private static class BinSparseApplyTask extends ColumnEncoder.ColumnApplyTask<ColumnEncoderBin> {
        public BinSparseApplyTask(ColumnEncoderBin encoder, CacheBlock<?> input, MatrixBlock out, int outputCol, int startRow, int blk) {
            super(encoder, input, out, outputCol, startRow, blk);
        }

        private BinSparseApplyTask(ColumnEncoderBin encoder, CacheBlock<?> input, MatrixBlock out, int outputCol) {
            super(encoder, input, out, outputCol);
        }

        @Override
        public Object call() throws Exception {
            final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            if (_out.getSparseBlock() == null) {
                return null;
            }
            ((ColumnEncoderBin) _encoder).applySparse(_input, _out, _outputCol, _startRow, _blk);
            if (DMLScript.STATISTICS) {
                TransformStatistics.incBinningApplyTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "<ColId: " + ((ColumnEncoderBin) _encoder)._colID + ">";
        }
    }

    /** Supported binning methods. INVALID has no textual form. */
    public static enum BinMethod {
        INVALID,
        EQUI_WIDTH,
        EQUI_HEIGHT,
        EQUI_HEIGHT_APPROX;

        /**
         * Textual name used by setBinMethod and by serialization. EQUI_HEIGHT_APPROX is spelled
         * with underscores unlike the others; kept as-is because persisted encoders match this text.
         */
        @Override
        public String toString() {
            switch (this) {
                case EQUI_WIDTH:
                    return "EQUI-WIDTH";
                case EQUI_HEIGHT:
                    return "EQUI-HEIGHT";
                case EQUI_HEIGHT_APPROX:
                    return "EQUI_HEIGHT_APPROX";
                default:
                    throw new DMLRuntimeException("Invalid encoder type.");
            }
        }
    }
}

