/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.functions.IsFrameBlockInRange;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for right indexing ({@code rightIndex}) and left indexing
 * ({@code leftIndex}, {@code mapLeftIndex}) over frames stored as
 * {@code JavaPairRDD<Long, FrameBlock>}, where each key is the 1-based global
 * start row of its block.
 */
public class FrameIndexingSPInstruction
extends IndexingSPInstruction {
    protected FrameIndexingSPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(in, rl, ru, cl, cu, out, aggtype, opcode, istr);
    }

    protected FrameIndexingSPInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) {
        super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr);
    }

    /**
     * Resolves the indexing range from the scalar operands and dispatches to
     * right or left indexing based on the opcode.
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException on an unknown opcode, invalid output dimensions,
     *         or a left-indexing range that does not match the right input
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        String opcode = getOpcode();

        // 1-based, inclusive index range [rl:ru, cl:cu]
        long rl = ec.getScalarInput(rowLower).getLongValue();
        long ru = ec.getScalarInput(rowUpper).getLongValue();
        long cl = ec.getScalarInput(colLower).getLongValue();
        long cu = ec.getScalarInput(colUpper).getLongValue();
        IndexRange ixrange = new IndexRange(rl, ru, cl, cu);

        if (opcode.equalsIgnoreCase("rightIndex")) {
            processRightIndexing(sec, ixrange);
        } else if (opcode.equalsIgnoreCase("leftIndex") || opcode.equalsIgnoreCase("mapLeftIndex")) {
            processLeftIndexing(sec, opcode, ixrange);
        } else {
            throw new DMLRuntimeException("Invalid opcode (" + opcode + ") encountered in FrameIndexingSPInstruction.");
        }
    }

    /** Slices {@code ixrange} out of input1 and registers the result (RDD, lineage, schema) as the output. */
    private void processRightIndexing(SparkExecutionContext sec, IndexRange ixrange) {
        long rl = ixrange.rowStart;
        long ru = ixrange.rowEnd;
        long cl = ixrange.colStart;
        long cu = ixrange.colEnd;

        // update and validate output dimensions
        DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        mcOut.set(ru - rl + 1L, cu - cl + 1L, mcIn.getBlocksize(), mcIn.getBlocksize());
        checkValidOutputDimensions(mcOut);

        JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
        JavaPairRDD<Long, FrameBlock> out;
        if (isPartitioningPreservingRightIndexing(mcIn, ixrange)) {
            // all rows selected: keys are unchanged, so partitioning can be preserved
            out = in1.mapPartitionsToPair(new SliceBlockPartitionFunction(ixrange, mcOut), true);
        } else {
            out = in1.filter(new IsFrameBlockInRange(rl, ru, mcOut))
                .mapToPair(new SliceBlock(ixrange, mcOut));
        }

        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
        // output schema is the selected column subset of the input schema
        sec.getFrameObject(output.getName()).setSchema(
            sec.getFrameObject(input1.getName()).getSchema((int) cl, (int) cu));
    }

    /**
     * Writes input2 into the {@code ixrange} region of input1. {@code mapLeftIndex}
     * broadcasts input2; {@code leftIndex} zeroes out the LHS region, shifts the RHS
     * and merges both by key.
     */
    private void processLeftIndexing(SparkExecutionContext sec, String opcode, IndexRange ixrange) {
        long rl = ixrange.rowStart;
        long ru = ixrange.rowEnd;
        long cl = ixrange.colStart;
        long cu = ixrange.colEnd;

        JavaPairRDD<Long, FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
        PartitionedBroadcast<FrameBlock> broadcastIn2 = null;
        JavaPairRDD<Long, FrameBlock> in2 = null;
        JavaPairRDD<Long, FrameBlock> out;

        // output has the dimensions of the left input
        DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
        DataCharacteristics mcLeft = sec.getDataCharacteristics(input1.getName());
        mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getBlocksize(), mcLeft.getBlocksize());
        checkValidOutputDimensions(mcOut);

        // the index range must match the right input's dimensions exactly
        DataCharacteristics mcRight = sec.getDataCharacteristics(input2.getName());
        if (!mcRight.dimsKnown()) {
            throw new DMLRuntimeException("The right input frame dimensions are not specified for FrameIndexingSPInstruction");
        }
        if (ru - rl + 1L != mcRight.getRows() || cu - cl + 1L != mcRight.getCols()) {
            throw new DMLRuntimeException("Invalid index range of leftindexing: [" + rl + ":" + ru + "," + cl + ":" + cu + "] vs [" + mcRight.getRows() + "x" + mcRight.getCols() + "].");
        }

        if (opcode.equalsIgnoreCase("mapLeftIndex")) {
            broadcastIn2 = sec.getBroadcastForFrameVariable(input2.getName());
            out = in1.mapPartitionsToPair(new LeftIndexPartitionFunction(broadcastIn2, ixrange, mcOut), true);
        } else {
            // general case: zero out the LHS region, shift the RHS into LHS blocking, merge by key
            in1 = in1.flatMapToPair(new ZeroOutLHS(false, ixrange, mcLeft));
            in2 = sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName())
                .flatMapToPair(new SliceRHSForLeftIndexing(ixrange, mcLeft));
            out = FrameRDDAggregateUtils.mergeByKey(in1.union(in2));
        }

        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
        if (broadcastIn2 != null) {
            sec.addLineageBroadcast(output.getName(), input2.getName());
        }
        if (in2 != null) {
            sec.addLineageRDD(output.getName(), input2.getName());
        }
    }

    /** True if dimensions are known and all rows are selected, so block keys are unchanged by slicing. */
    private static boolean isPartitioningPreservingRightIndexing(DataCharacteristics mcIn, IndexRange ixrange) {
        return mcIn.dimsKnown() && ixrange.rowStart == 1L && ixrange.rowEnd == mcIn.getRows();
    }

    /** @throws DMLRuntimeException if {@code mcOut} does not have known dimensions */
    private static void checkValidOutputDimensions(DataCharacteristics mcOut) {
        if (!mcOut.dimsKnown()) {
            throw new DMLRuntimeException("FrameIndexingSPInstruction: The updated output dimensions are invalid: " + mcOut);
        }
    }

    /**
     * Row-block size used by general left indexing for both {@link ZeroOutLHS} and
     * {@link SliceRHSForLeftIndexing}. Both sides MUST use the same value, so that
     * their output keys and block heights line up for {@code mergeByKey}.
     * Previously both sides overwrote this with the number of columns.
     */
    private static int leftIndexingBlockSize(long rlen) {
        return (int) Math.max(1L, Math.min((long) OptimizerUtils.getDefaultFrameSize(), rlen));
    }

    /** Column-slices every block of a partition; row keys are kept as-is. */
    private static class SliceBlockPartitionFunction
    implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Long, FrameBlock> {
        private static final long serialVersionUID = -1655390518299307588L;
        private final IndexRange _ixrange;

        public SliceBlockPartitionFunction(IndexRange ixrange, DataCharacteristics mcOut) {
            _ixrange = ixrange;
        }

        @Override
        public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> arg0) throws Exception {
            return new SliceBlockPartitionIterator(arg0);
        }

        private class SliceBlockPartitionIterator
        extends LazyIterableIterator<Tuple2<Long, FrameBlock>> {
            public SliceBlockPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg) throws Exception {
                long rowindex = arg._1();
                FrameBlock in = arg._2();
                // keep all rows, select 0-based columns [colStart-1, colEnd-1]
                FrameBlock out = in.slice(0, in.getNumRows() - 1, (int) _ixrange.colStart - 1, (int) _ixrange.colEnd - 1, new FrameBlock());
                return new Tuple2<>(rowindex, out);
            }
        }
    }

    /** Slices one in-range block and re-keys it relative to the range's first row. */
    private static class SliceBlock
    implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = -5270171193018691692L;
        private final IndexRange _ixrange;

        public SliceBlock(IndexRange ixrange, DataCharacteristics mcOut) {
            _ixrange = ixrange;
        }

        @Override
        public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> kv) throws Exception {
            long rowindex = kv._1();
            FrameBlock in = kv._2();
            // local 0-based row bounds of the part of this block inside [rowStart, rowEnd]
            int rl = (int) (rowindex > _ixrange.rowStart ? 0L : _ixrange.rowStart - rowindex);
            int ru = (int) (_ixrange.rowEnd - rowindex >= (long) in.getNumRows() ? (long) (in.getNumRows() - 1) : _ixrange.rowEnd - rowindex);
            FrameBlock out = in.slice(rl, ru, (int) (_ixrange.colStart - 1L), (int) (_ixrange.colEnd - 1L), new FrameBlock());
            // new 1-based start row within the output frame
            long rowindex2 = rowindex > _ixrange.rowStart ? rowindex - _ixrange.rowStart + 1L : 1L;
            return new Tuple2<>(rowindex2, out);
        }
    }

    /** Map-side left indexing: overwrites the in-range region of each LHS block from a broadcast RHS. */
    private static class LeftIndexPartitionFunction
    implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Long, FrameBlock> {
        private static final long serialVersionUID = -911940376947364915L;
        // Rows of the broadcast RHS sliced per leftIndexingOperations call.
        // NOTE(review): 1000 presumably mirrors the broadcast partition size — confirm.
        private static final int RHS_SLICE_BLOCKSIZE = 1000;
        private final PartitionedBroadcast<FrameBlock> _binput;
        private final IndexRange _ixrange;

        public LeftIndexPartitionFunction(PartitionedBroadcast<FrameBlock> binput, IndexRange ixrange, DataCharacteristics mc) {
            _binput = binput;
            _ixrange = ixrange;
        }

        @Override
        public LazyIterableIterator<Tuple2<Long, FrameBlock>> call(Iterator<Tuple2<Long, FrameBlock>> arg0) throws Exception {
            return new LeftIndexPartitionIterator(arg0);
        }

        private class LeftIndexPartitionIterator
        extends LazyIterableIterator<Tuple2<Long, FrameBlock>> {
            public LeftIndexPartitionIterator(Iterator<Tuple2<Long, FrameBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg) throws Exception {
                long blockStart = arg._1();
                FrameBlock lhs = arg._2();
                int iNumRowsInBlock = lhs.getNumRows();
                int iNumCols = lhs.getNumColumns();
                // blocks outside the index range pass through untouched
                if (!UtilFunctions.isInFrameBlockRange(arg._1(), iNumRowsInBlock, _ixrange)) {
                    return arg;
                }
                // global (1-based) LHS region covered by this block
                long lhs_rl = Math.max(_ixrange.rowStart, blockStart);
                long lhs_ru = Math.min(_ixrange.rowEnd, blockStart + (long) iNumRowsInBlock - 1L);
                long lhs_cl = Math.max(_ixrange.colStart, 1L);
                long lhs_cu = Math.min(_ixrange.colEnd, (long) iNumCols);
                // corresponding global (1-based) RHS region
                long rhs_rl = lhs_rl - _ixrange.rowStart + 1L;
                long rhs_ru = rhs_rl + (lhs_ru - lhs_rl);
                long rhs_cl = lhs_cl - _ixrange.colStart + 1L;
                long rhs_cu = rhs_cl + (lhs_cu - lhs_cl);
                // local (0-based) LHS indexes within this block
                int lhs_lrl = (int) (lhs_rl - blockStart);
                int lhs_lru = (int) (lhs_ru - blockStart);
                int lhs_lcl = (int) lhs_cl - 1;
                int lhs_lcu = (int) lhs_cu - 1;

                FrameBlock ret = lhs;
                int blen = RHS_SLICE_BLOCKSIZE;
                long rhs_rl_pb = rhs_rl;
                // first chunk ends at the RHS block boundary that follows rhs_rl
                long rhs_ru_pb = Math.min(rhs_ru, ((rhs_rl - 1L) / (long) blen + 1L) * (long) blen);
                while (rhs_rl_pb <= rhs_ru_pb) {
                    FrameBlock slicedRHSMatBlock = _binput.slice(rhs_rl_pb, rhs_ru_pb, rhs_cl, rhs_cu, new FrameBlock());
                    int lhs_lrl_pb = (int) ((long) lhs_lrl + (rhs_rl_pb - rhs_rl));
                    int lhs_lru_pb = (int) ((long) lhs_lru + (rhs_ru_pb - rhs_ru));
                    ret = ret.leftIndexingOperations(slicedRHSMatBlock, lhs_lrl_pb, lhs_lru_pb, lhs_lcl, lhs_lcu, new FrameBlock());
                    rhs_rl_pb = rhs_ru_pb + 1L;
                    rhs_ru_pb = Math.min(rhs_ru, rhs_ru_pb + (long) blen);
                }
                return new Tuple2<>(arg._1(), ret);
            }
        }
    }

    /**
     * Re-blocks an LHS block into {@code _blen}-row output blocks keyed by their
     * global start row, zeroing out (or, if complement, keeping only) the index range.
     */
    private static class ZeroOutLHS
    implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = -2672267231152496854L;
        private final boolean _complement;
        private final IndexRange _ixrange;
        private final int _blen;
        private final long _rlen;

        public ZeroOutLHS(boolean complement, IndexRange range, DataCharacteristics mcLeft) {
            _complement = complement;
            _ixrange = range;
            _rlen = mcLeft.getRows();
            // must equal SliceRHSForLeftIndexing's block size so keys align in mergeByKey
            _blen = leftIndexingBlockSize(_rlen);
        }

        @Override
        public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> kv) throws Exception {
            long blockKey = kv._1();
            FrameBlock in = kv._2();
            ArrayList<Pair<Long, FrameBlock>> out = new ArrayList<>();
            IndexRange curBlockRange = new IndexRange(_ixrange.rowStart, _ixrange.rowEnd, _ixrange.colStart, _ixrange.colEnd);
            // global 1-based start row of the target block containing blockKey
            long lGblStartRow = (blockKey - 1L) / (long) _blen * (long) _blen + 1L;
            int iMaxRowsToCopy;
            // local 0-based row in the target block where copying starts
            int iRowStartDest = UtilFunctions.computeCellInBlock(blockKey, _blen);
            int iRowStartSrc = 0;
            while (iRowStartSrc < in.getNumRows()) {
                IndexRange range = UtilFunctions.getSelectedRangeForZeroOut(new Pair<Long, FrameBlock>(kv._1(), in), _blen, curBlockRange, lGblStartRow - 1L, lGblStartRow);
                if (range.rowStart == -1L && range.rowEnd == -1L && range.colStart == -1L && range.colEnd == -1L) {
                    throw new DMLRuntimeException("Error while getting range for zero-out");
                }
                // height of the target block (last block may be shorter)
                int iMaxRows = (int) Math.min((long) _blen, _rlen - lGblStartRow + 1L);
                iMaxRowsToCopy = Math.min(iMaxRows, in.getNumRows() - iRowStartSrc);
                iMaxRowsToCopy = Math.min(iMaxRowsToCopy, iMaxRows - iRowStartDest);
                if (iMaxRowsToCopy <= 0) {
                    // would otherwise loop forever without advancing iRowStartSrc
                    throw new DMLRuntimeException("ZeroOutLHS: cannot re-block block at row " + blockKey
                        + " (rlen=" + _rlen + ", blen=" + _blen + ").");
                }
                FrameBlock zeroBlk = in.zeroOutOperations(new FrameBlock(), range, _complement, iRowStartSrc, iRowStartDest, iMaxRows, iMaxRowsToCopy);
                out.add(new Pair<Long, FrameBlock>(lGblStartRow, zeroBlk));
                curBlockRange.rowStart = lGblStartRow + (long) _blen;
                iRowStartDest = UtilFunctions.computeCellInBlock(iRowStartDest + iMaxRowsToCopy + 1, _blen);
                iRowStartSrc += iMaxRowsToCopy;
                lGblStartRow += (long) _blen;
            }
            return SparkUtils.fromIndexedFrameBlock(out).iterator();
        }
    }

    /** Shifts an RHS block into LHS coordinates, split along the LHS row-block grid. */
    private static class SliceRHSForLeftIndexing
    implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
        private static final long serialVersionUID = 5724800998701216440L;
        private final IndexRange _ixrange;
        private final int _blen;
        private final long _rlen;
        private final long _clen;

        public SliceRHSForLeftIndexing(IndexRange ixrange, DataCharacteristics mcLeft) {
            _ixrange = ixrange;
            _rlen = mcLeft.getRows();
            _clen = mcLeft.getCols();
            // must equal ZeroOutLHS's block size so keys align in mergeByKey
            _blen = leftIndexingBlockSize(_rlen);
        }

        @Override
        public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> rightKV) throws Exception {
            Pair<Long, FrameBlock> in = SparkUtils.toIndexedFrameBlock(rightKV);
            ArrayList<Pair<Long, FrameBlock>> out = new ArrayList<>();
            OperationsOnMatrixValues.performShift(in, _ixrange, _blen, _rlen, _clen, out);
            return SparkUtils.fromIndexedFrameBlock(out).iterator();
        }
    }
}

