/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;

/**
 * Federated implementation of the {@code ctable} (contingency table) instruction.
 *
 * <p>At least one of input1, input2 or a matrix input3 (weights) must be federated. The federation map of
 * that input acts as the anchor. The other matrix inputs are broadcast in slices aligned with its ranges,
 * and every worker computes a local ctable over its partition.
 *
 * <p>Sometimes the per-partition values of the broadcast index input form strictly ascending,
 * non-overlapping intervals in federated-range order. In that case the local results are kept on the
 * workers as a federated output. Otherwise they are fetched and summed into one local matrix.
 */
public class CtableFEDInstruction
extends ComputationFEDInstruction {
    /** First output-dimension operand (FP64 scalar, literal or variable). */
    private final CPOperand _outDim1;
    /** Second output-dimension operand (FP64 scalar, literal or variable). */
    private final CPOperand _outDim2;

    /**
     * Creates the instruction. {@code isExpand} and {@code ignoreZeros} are accepted for signature
     * compatibility but are not stored by this class.
     */
    private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros, String opcode, String istr) {
        super(FEDInstruction.FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
        this._outDim1 = new CPOperand(outputDim1, Types.ValueType.FP64, Types.DataType.SCALAR, dim1Literal);
        this._outDim2 = new CPOperand(outputDim2, Types.ValueType.FP64, Types.DataType.SCALAR, dim2Literal);
    }

    /**
     * Parses a {@code ctable} instruction string.
     *
     * <p>The string is expected to hold the opcode plus seven operands: in1, in2, in3, dim1, dim2, out,
     * ignoreZeros. The two dimension operands have the form {@code name\u00b7isLiteral}.
     *
     * @param inst the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ctable}
     */
    public static CtableFEDInstruction parseInstruction(String inst) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
        InstructionUtils.checkNumFields(parts, 7);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ctable")) {
            throw new DMLRuntimeException("Unexpected opcode in CtableFEDInstruction: " + inst);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        String[] dim1Fields = parts[4].split("\u00b7");
        String[] dim2Fields = parts[5].split("\u00b7");
        CPOperand out = new CPOperand(parts[6]);
        boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
        return new CtableFEDInstruction(in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), false, ignoreZeros, opcode, inst);
    }

    /**
     * Executes the federated ctable.
     *
     * <p>The federated input is moved into {@code mo1`: input2 is swapped in when only input2 is federated,
     * and the weights are swapped in when only they are. The per-partition output dimensions are then
     * computed, and the request is dispatched.
     *
     * @throws DMLRuntimeException if none of input1, input2 or a matrix input3 is federated
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(this.input1);
        MatrixObject mo2 = ec.getMatrixObject(this.input2);
        boolean reversed = false;
        if (!mo1.isFederated() && mo2.isFederated()) {
            mo1 = ec.getMatrixObject(this.input2);
            mo2 = ec.getMatrixObject(this.input1);
            reversed = true;
        }
        MatrixObject mo3 = this.input3 != null && this.input3.isMatrix() ? ec.getMatrixObject(this.input3) : null;
        boolean reversedWeights = mo3 != null && mo3.isFederated() && !mo1.isFederated() && !mo2.isFederated();

        // The ranges used for slicing must come from the federated input. When only the weights are
        // federated, mo1 is local and has no fed mapping, so the weights' map is the anchor.
        FederationMap anchorMap = reversedWeights ? mo3.getFedMapping() : mo1.getFedMapping();
        if (anchorMap == null) {
            throw new DMLRuntimeException("Federated ctable requires a federated input1, input2 or matrix input3: " + this.instString);
        }
        FederatedRange[] anchorRanges = anchorMap.getFederatedRanges();

        Long[] dims1 = this.getOutputDimension(mo1, reversed ? this.input2 : this.input1, reversed ? this._outDim2 : this._outDim1, anchorRanges);
        Long[] dims2 = this.getOutputDimension(mo2, reversed ? this.input1 : this.input2, reversed ? this._outDim1 : this._outDim2, anchorRanges);

        if (reversedWeights) {
            // From here on mo1 is the federated weights and mo3 holds (local) input1.
            mo3 = mo1;
            mo1 = ec.getMatrixObject(this.input3);
        }
        long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
        boolean fedOutput = CtableFEDInstruction.isFedOutput(mo1.getFedMapping(), mo2);
        this.processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
    }

    /**
     * Broadcasts the non-anchor inputs in slices, runs the ctable on every worker, and either keeps the
     * result federated or collects and sums it.
     *
     * @param mo1             the federated anchor input
     * @param mo2             the other index input, broadcast in slices
     * @param mo3             the third matrix input (weights, or input1 when {@code reversedWeights}), may be null
     * @param reversed        whether mo1 is input2 and mo2 is input1
     * @param reversedWeights whether mo1 is input3 and mo3 is input1
     * @param fedOutput       whether the output stays federated
     * @param staticDim       size of the non-partitioned output dimension
     * @param dims2           per-partition upper bounds of the partitioned output dimension
     */
    private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3, boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
        FederatedRequest[] fr2 = null;
        FederatedRequest fr3;
        if (mo3 != null && mo1.isFederated() && mo3.isFederated() && fedMap.isAligned(mo3.getFedMapping(), FederationMap.AlignType.FULL)) {
            // Weights are federated and aligned with the anchor: use them in place.
            long weightsId = mo3.getFedMapping().getID();
            long[] ids = reversed
                ? new long[]{fr1[0].getID(), fedMap.getID(), weightsId}
                : new long[]{fedMap.getID(), fr1[0].getID(), weightsId};
            fr3 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2, this.input3}, ids);
        } else if (mo3 == null) {
            long[] ids = reversed
                ? new long[]{fr1[0].getID(), fedMap.getID()}
                : new long[]{fedMap.getID(), fr1[0].getID()};
            fr3 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2}, ids);
        } else {
            fr2 = fedMap.broadcastSliced(mo3, false);
            long[] ids;
            if (reversedWeights) {
                // fr1 carries input2 (mo2), fr2 carries input1 (mo3), the anchor map is input3.
                ids = new long[]{fr2[0].getID(), fr1[0].getID(), fedMap.getID()};
            } else if (reversed) {
                ids = new long[]{fr1[0].getID(), fedMap.getID(), fr2[0].getID()};
            } else {
                ids = new long[]{fedMap.getID(), fr1[0].getID(), fr2[0].getID()};
            }
            fr3 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2, this.input3}, ids);
        }

        if (fedOutput) {
            if (fr2 != null) {
                fedMap.execute(this.getTID(), true, fr1, fr2, new FederatedRequest[]{fr3});
            } else {
                fedMap.execute(this.getTID(), true, fr1, new FederatedRequest[]{fr3});
            }
            MatrixObject out = ec.getMatrixObject(this.output);
            FederationMap newFedMap = CtableFEDInstruction.modifyFedRanges(fedMap.copyWithNewID(fr3.getID()), staticDim, dims2, reversed);
            CtableFEDInstruction.setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
        } else {
            FederatedRequest fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
            FederatedRequest fr5 = fedMap.cleanup(this.getTID(), fr3.getID());
            Future<FederatedResponse>[] ffr = fr2 != null
                ? fedMap.execute(this.getTID(), true, fr1, fr2, new FederatedRequest[]{fr3, fr4, fr5})
                : fedMap.execute(this.getTID(), true, fr1, new FederatedRequest[]{fr3, fr4, fr5});
            ec.setMatrixOutput(this.output.getName(), CtableFEDInstruction.aggResult(ffr));
        }
    }

    /**
     * Decides whether the ctable output can remain federated.
     *
     * <p>Each worker's local result is later cut at the previous partition's maximum (see
     * {@link #modifyFedRanges} and {@link SliceOutput}). That only works if each slice {@code [min, max]} of
     * {@code mo2} starts strictly after the previous slice's maximum, in federated-range order. This implies
     * strictly ascending maxima. Any other layout falls back to the aggregated local output.
     *
     * @param fedMap the anchor federation map whose ranges slice {@code mo2}
     * @param mo2    the broadcast index input
     * @return true if the output can stay federated
     */
    private static boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
        MatrixBlock mb = (MatrixBlock)mo2.acquireReadAndRelease();
        double prevMax = Double.NEGATIVE_INFINITY;
        for (FederatedRange range : fedMap.getFederatedRanges()) {
            MatrixBlock sliced = CtableFEDInstruction.sliceByRange(mb, range);
            // Written as !(a < b) so NaN values also reject the federated output.
            if (!(prevMax < sliced.min())) {
                return false;
            }
            prevMax = sliced.max();
        }
        return true;
    }

    /**
     * Sets the output's metadata and federation map. It then asks every worker to trim or pad its local
     * result to the partition bounds.
     */
    private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, long staticDim, Long[] dims2, boolean reversed) {
        long splitDim = Collections.max(Arrays.asList(dims2));
        long d1 = reversed ? splitDim : staticDim;
        long d2 = reversed ? staticDim : splitDim;
        out.getDataCharacteristics().set(d1, d2, (int)mo1.getBlocksize(), mo1.getNnz());
        out.setFedMapping(fedMap);
        long varID = FederationUtils.getNextFedDataID();
        fedMap.mapParallel(varID, (range, data) -> {
            try {
                FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new SliceOutput(data.getVarID(), staticDim, dims2, reversed))).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        });
    }

    /**
     * Sums the per-worker ctable results.
     *
     * <p>Each partial result is zero-padded to the running maximum row and column count before adding.
     *
     * @throws DMLRuntimeException if any worker response cannot be obtained (previously the failure was
     *         only printed and the partial result silently dropped)
     */
    private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
        BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
        MatrixBlock resultBlock = new MatrixBlock(1, 1, true, 0L);
        for (int i = 0; i < ffr.length; ++i) {
            MatrixBlock mb = (MatrixBlock)CtableFEDInstruction.getResponseData(ffr[i], i);
            int rows = Math.max(resultBlock.getNumRows(), mb.getNumRows());
            int cols = Math.max(resultBlock.getNumColumns(), mb.getNumColumns());
            MatrixBlock prev = new MatrixBlock(rows, cols, true, 0L);
            prev.copy(0, resultBlock.getNumRows() - 1, 0, resultBlock.getNumColumns() - 1, resultBlock, true);
            MatrixBlock next = new MatrixBlock(rows, cols, true, 0L);
            next.copy(0, mb.getNumRows() - 1, 0, mb.getNumColumns() - 1, mb, true);
            resultBlock = prev.binaryOperationsInPlace(plus, next);
        }
        return resultBlock;
    }

    /**
     * Rewrites the ranges of {@code fedMap} in place for the federated output.
     *
     * <p>The static dimension spans {@code [0, staticDim)}. Along the partitioned dimension, range {@code i}
     * covers {@code [dims2[i-1], dims2[i])}, starting at 0 for the first range. The partitioned dimension
     * is rows when {@code reversed}, columns otherwise.
     *
     * @return the same (mutated) map
     */
    private static FederationMap modifyFedRanges(FederationMap fedMap, long staticDim, Long[] dims2, boolean reversed) {
        int staticAxis = reversed ? 1 : 0;
        int splitAxis = reversed ? 0 : 1;
        FederatedRange[] ranges = fedMap.getFederatedRanges();
        for (int i = 0; i < ranges.length; ++i) {
            FederatedRange fedRange = ranges[i];
            fedRange.setBeginDim(staticAxis, 0L);
            fedRange.setEndDim(staticAxis, staticDim);
            fedRange.setBeginDim(splitAxis, i == 0 ? 0L : dims2[i - 1]);
            fedRange.setEndDim(splitAxis, dims2[i]);
        }
        return fedMap;
    }

    /**
     * Computes the maximum value of {@code in} within each partition.
     *
     * <p>For a local input, the block is sliced by {@code federatedRanges}. For a federated input, a
     * {@code uamax} instruction runs on every worker.
     *
     * @return one maximum per partition, truncated to {@code long}
     */
    private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp, FederatedRange[] federatedRanges) {
        if (!in.isFederated()) {
            MatrixBlock mb = (MatrixBlock)in.acquireReadAndRelease();
            Long[] fedDims = new Long[federatedRanges.length];
            for (int i = 0; i < federatedRanges.length; ++i) {
                fedDims[i] = (long)CtableFEDInstruction.sliceByRange(mb, federatedRanges[i]).max();
            }
            return fedDims;
        }
        String maxInstString = this.constructMaxInstString(inOp.getName(), outOp.getName());
        FederationMap map = in.getFedMapping();
        FederatedRequest fr1 = FederationUtils.callInstruction(maxInstString, outOp, new CPOperand[]{inOp}, new long[]{map.getID()});
        FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
        FederatedRequest fr3 = map.cleanup(this.getTID(), fr1.getID());
        Future<FederatedResponse>[] tmp = map.execute(this.getTID(), fr1, fr2, fr3);
        return CtableFEDInstruction.computeOutputDims(tmp);
    }

    /**
     * Extracts the scalar maximum returned by each worker.
     *
     * @throws DMLRuntimeException if any response cannot be obtained (previously this left a {@code null}
     *         entry that failed later without context)
     */
    private static Long[] computeOutputDims(Future<FederatedResponse>[] tmp) {
        Long[] fedDims = new Long[tmp.length];
        for (int i = 0; i < tmp.length; ++i) {
            fedDims[i] = ((ScalarObject)CtableFEDInstruction.getResponseData(tmp[i], i)).getLongValue();
        }
        return fedDims;
    }

    /**
     * Waits for a federated response and returns its first data object. Failures are wrapped in a
     * {@link DMLRuntimeException}, and the interrupt flag is restored if the wait was interrupted.
     */
    private static Object getResponseData(Future<FederatedResponse> future, int index) {
        try {
            return future.get().getData()[0];
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Interrupted while waiting for federated response " + index, e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Failed to obtain federated response " + index, e);
        }
    }

    /** Slices {@code mb} to the bounds of {@code range}; the range's end indices are exclusive. */
    private static MatrixBlock sliceByRange(MatrixBlock mb, FederatedRange range) {
        int[] begin = range.getBeginDimsInt();
        int[] end = range.getEndDimsInt();
        return mb.slice(begin[0], end[0] - 1, begin[1], end[1] - 1);
    }

    /**
     * Builds a {@code uamax} instruction string that reuses this instruction's execution-type prefix.
     * It reads matrix {@code in} and writes scalar {@code out}, using 16 threads.
     */
    private String constructMaxInstString(String in, String out) {
        String[] instParts = this.instString.split("\u00b0");
        // Only the execution type and opcode are kept; the opcode is replaced outright.
        CharSequence[] maxInstParts = new String[]{instParts[0], "uamax", InstructionUtils.concatOperandParts(in, Types.DataType.MATRIX.name(), Types.ValueType.FP64.name()), InstructionUtils.concatOperandParts(out, Types.DataType.SCALAR.name(), Types.ValueType.FP64.name()), "16"};
        return String.join((CharSequence)"\u00b0", maxInstParts);
    }

    /**
     * Worker-side UDF that reshapes a local ctable result into its share of the federated output.
     *
     * <p>It first zero-pads the static dimension up to {@code _staticDim}. It then keeps only the slice
     * {@code [begin, end)} of the partitioned dimension. Here {@code end} is the local result's size along
     * that dimension, and {@code begin} is the entry of {@code _fedDims} that precedes the one equal to
     * {@code end}, or 0.
     */
    private static class SliceOutput
    extends FederatedUDF {
        private static final long serialVersionUID = -2808597461054603816L;
        private final int _staticDim;
        private final Long[] _fedDims;
        private final boolean _reversed;

        protected SliceOutput(long input, long staticDim, Long[] fedDims, boolean reversed) {
            super(new long[]{input});
            this._staticDim = (int)staticDim;
            this._fedDims = fedDims;
            this._reversed = reversed;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject mo = (MatrixObject)data[0];
            MatrixBlock mb = (MatrixBlock)mo.acquireReadAndRelease();
            // Partitioned dimension: rows when reversed, columns otherwise.
            int endDim = this._reversed ? mb.getNumRows() : mb.getNumColumns();
            int localStaticDim = this._reversed ? mb.getNumColumns() : mb.getNumRows();
            int beginDim = 0;
            for (int i = 0; i < this._fedDims.length; ++i) {
                if (this._fedDims[i].longValue() == endDim) {
                    beginDim = i == 0 ? 0 : this._fedDims[i - 1].intValue();
                    break;
                }
            }
            mb = this.expandMatrix(mb, localStaticDim);
            MatrixBlock sliced = this._reversed
                ? mb.slice(beginDim, endDim - 1, 0, this._staticDim - 1)
                : mb.slice(0, this._staticDim - 1, beginDim, endDim - 1);
            mo.acquireModify(sliced);
            mo.release();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[0]);
        }

        /** Appends zeros along the static dimension until it reaches {@code _staticDim}. */
        private MatrixBlock expandMatrix(MatrixBlock mb, int localStaticDim) {
            int diff = this._staticDim - localStaticDim;
            if (diff > 0) {
                MatrixBlock tmpMb = this._reversed ? new MatrixBlock(mb.getNumRows(), diff, 0.0) : new MatrixBlock(diff, mb.getNumColumns(), 0.0);
                mb = mb.append(tmpMb, null, this._reversed);
            }
            return mb;
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

