/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.BinaryM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
import org.apache.sysds.runtime.instructions.fed.AppendFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.BinaryMatrixMatrixFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.BinaryMatrixScalarFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.CovarianceFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.CumulativeOffsetFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.MMFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuantilePickFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.CovarianceSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Base class for federated instructions that operate on two (or three) operands.
 * <p>
 * Besides the constructors used by concrete subclasses, this class offers static factory
 * methods that translate CP and Spark binary instructions into their federated counterparts
 * when one of the matrix inputs is federated with a type other than
 * {@link FTypes.FType#BROADCAST}, plus small parsing helpers shared by subclasses.
 */
public abstract class BinaryFEDInstruction
extends ComputationFEDInstruction {
    /** Operand delimiter used inside instruction strings. */
    private static final String OPERAND_DELIM = "\u00b0";

    /** Length of the "map" prefix carried by Spark map-side binary opcodes (e.g. {@code map+}). */
    private static final String MAP_OPCODE_PREFIX = "map";

    /**
     * Creates a binary federated instruction with an explicit federated-output mode.
     *
     * @param type   the federated instruction type
     * @param op     the operator to apply
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   original instruction string
     * @param fedOut federated output mode of the result
     */
    protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(type, op, in1, in2, out, opcode, istr, fedOut);
    }

    /**
     * Creates a binary federated instruction whose federated output mode is
     * {@link FEDInstruction.FederatedOutput#NONE}.
     */
    protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        this(type, op, in1, in2, out, opcode, istr, FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Creates a federated instruction with three input operands.
     */
    public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Converts a CP binary instruction into a federated instruction if at least one matrix
     * input is federated (excluding {@link FTypes.FType#BROADCAST}).
     *
     * @param inst the CP instruction to convert
     * @param ec   execution context used to look up the input matrices
     * @return the federated instruction, or {@code null} if no input is federated or the
     *         instruction type has no federated counterpart here
     */
    public static BinaryFEDInstruction parseInstruction(BinaryCPInstruction inst, ExecutionContext ec) {
        if (inst.input1.isMatrix() && ec.getMatrixObject(inst.input1).isFederatedExcept(FTypes.FType.BROADCAST) || inst.input2 != null && inst.input2.isMatrix() && ec.getMatrixObject(inst.input2).isFederatedExcept(FTypes.FType.BROADCAST)) {
            if (inst instanceof AppendCPInstruction) {
                return AppendFEDInstruction.parseInstruction((AppendCPInstruction)inst);
            }
            if (inst instanceof QuantilePickCPInstruction) {
                return QuantilePickFEDInstruction.parseInstruction((QuantilePickCPInstruction)inst);
            }
            // covariance is only federated for row-partitioned inputs
            if (inst instanceof CovarianceCPInstruction && (ec.getMatrixObject(inst.input1).isFederated(FTypes.FType.ROW) || ec.getMatrixObject(inst.input2).isFederated(FTypes.FType.ROW))) {
                return CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction)inst);
            }
            if (inst instanceof BinaryMatrixMatrixCPInstruction) {
                return BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixCPInstruction)inst);
            }
            if (inst instanceof BinaryMatrixScalarCPInstruction) {
                return BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarCPInstruction)inst);
            }
        }
        return null;
    }

    /**
     * Converts a Spark binary instruction into a federated instruction if the relevant matrix
     * input(s) are federated (excluding {@link FTypes.FType#BROADCAST}).
     *
     * @param inst the Spark instruction to convert
     * @param ec   execution context used to look up the inputs
     * @return the federated instruction, or {@code null} if the instruction is not eligible
     */
    public static BinaryFEDInstruction parseInstruction(BinarySPInstruction inst, ExecutionContext ec) {
        if (inst instanceof MapmmSPInstruction || inst instanceof CpmmSPInstruction || inst instanceof RmmSPInstruction) {
            Data data = ec.getVariable(inst.input1);
            if (data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return MMFEDInstruction.parseInstruction((AggregateBinarySPInstruction)inst);
            }
        } else if (inst instanceof QuantilePickSPInstruction) {
            QuantilePickSPInstruction qinstruction = (QuantilePickSPInstruction)inst;
            Data data = ec.getVariable(qinstruction.input1);
            if (data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return QuantilePickFEDInstruction.parseInstruction(qinstruction);
            }
        } else if (inst instanceof AppendGAlignedSPInstruction || inst instanceof AppendGSPInstruction || inst instanceof AppendMSPInstruction || inst instanceof AppendRSPInstruction) {
            Data data1 = ec.getVariable(inst.input1);
            Data data2 = ec.getVariable(inst.input2);
            if (data1 instanceof MatrixObject && ((MatrixObject)data1).isFederatedExcept(FTypes.FType.BROADCAST) || data2 instanceof MatrixObject && ((MatrixObject)data2).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return AppendFEDInstruction.parseInstruction((AppendSPInstruction)inst);
            }
        } else if (inst instanceof BinaryMatrixScalarSPInstruction) {
            // NOTE(review): only input1 is inspected here, whereas the CP overload also federates
            // when input2 is the federated matrix (e.g. scalar - X) -- confirm this is intended.
            Data data = ec.getVariable(inst.input1);
            if (data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarSPInstruction)inst);
            }
        } else if (inst instanceof BinaryMatrixMatrixSPInstruction) {
            // NOTE(review): as above, a federated input2 alone does not trigger federation here.
            Data data = ec.getVariable(inst.input1);
            if (data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixSPInstruction)inst);
            }
        } else if (inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederatedExcept(FTypes.FType.BROADCAST)
            // guard input2 against null, consistent with the CP overload
            || inst.input2 != null && inst.input2.isMatrix() && ec.getMatrixObject(inst.input2).isFederatedExcept(FTypes.FType.BROADCAST)) {
            if (inst instanceof CovarianceSPInstruction && (ec.getMatrixObject(inst.input1).isFederated(FTypes.FType.ROW) || ec.getMatrixObject(inst.input2).isFederated(FTypes.FType.ROW))) {
                return CovarianceFEDInstruction.parseInstruction((CovarianceSPInstruction)inst);
            }
            if (inst instanceof CumulativeOffsetSPInstruction) {
                return CumulativeOffsetFEDInstruction.parseInstruction((CumulativeOffsetSPInstruction)inst);
            }
            // generic fallback: re-parse the (rewritten) instruction string with fedOut = NONE
            return BinaryFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(), FEDInstruction.FederatedOutput.NONE.name()));
        }
        return null;
    }

    /**
     * Parses a binary instruction string into a federated matrix-matrix or matrix-scalar
     * instruction.
     * <p>
     * Spark instruction strings (starting with the SPARK exec type) are first rewritten into CP
     * form via {@link #rewriteSparkInstructionToCP(String)}. After the exec type, the fields are
     * expected as {@code opcode, in1, in2, out, ..., fedOut}; the last field is passed to
     * {@link FEDInstruction.FederatedOutput#valueOf(String)}, which throws
     * {@link IllegalArgumentException} if it is not a valid constant name.
     *
     * @param str the instruction string
     * @return the federated instruction
     * @throws DMLRuntimeException for scalar-scalar, tensor-tensor or other unsupported operand
     *         combinations, or when the output data type does not match the inputs
     */
    public static BinaryFEDInstruction parseInstruction(String str) {
        if (str.startsWith(Types.ExecType.SPARK.name())) {
            str = BinaryFEDInstruction.rewriteSparkInstructionToCP(str);
        }
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3, 4, 5, 6);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        FEDInstruction.FederatedOutput fedOut = FEDInstruction.FederatedOutput.valueOf(parts[parts.length - 1]);
        BinaryFEDInstruction.checkOutputDataType(in1, in2, out);
        MultiThreadedOperator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
        if (in1.getDataType() == Types.DataType.SCALAR && in2.getDataType() == Types.DataType.SCALAR) {
            throw new DMLRuntimeException("Federated binary scalar scalar operations not yet supported");
        }
        if (in1.getDataType() == Types.DataType.MATRIX && in2.getDataType() == Types.DataType.MATRIX) {
            return new BinaryMatrixMatrixFEDInstruction(operator, in1, in2, out, opcode, str, fedOut);
        }
        if (in1.getDataType() == Types.DataType.TENSOR && in2.getDataType() == Types.DataType.TENSOR) {
            throw new DMLRuntimeException("Federated binary tensor tensor operations not yet supported");
        }
        if (in1.isMatrix() && in2.isScalar() || in2.isMatrix() && in1.isScalar()) {
            return new BinaryMatrixScalarFEDInstruction(operator, in1, in2, out, opcode, str, fedOut);
        }
        throw new DMLRuntimeException("Federated binary operations not yet supported:" + opcode);
    }

    /**
     * Splits a two-input instruction string into the given operand objects.
     *
     * @param instr instruction string with 3 or 4 fields after the exec type
     * @param in1   receives the first input (via {@link CPOperand#split(String)})
     * @param in2   receives the second input
     * @param out   receives the output
     * @return the opcode
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, 3, 4);
        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        out.split(parts[3]);
        return opcode;
    }

    /**
     * Splits a three-input instruction string into the given operand objects.
     *
     * @param instr instruction string with exactly 4 fields after the exec type
     * @param in1   receives the first input
     * @param in2   receives the second input
     * @param in3   receives the third input
     * @param out   receives the output
     * @return the opcode
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, 4);
        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        in3.split(parts[3]);
        out.split(parts[4]);
        return opcode;
    }

    /**
     * Verifies that an element-wise operation with at least one matrix input produces a matrix.
     *
     * @throws DMLRuntimeException if an input is a matrix but the output is not
     */
    protected static void checkOutputDataType(CPOperand in1, CPOperand in2, CPOperand out) {
        if ((in1.getDataType() == Types.DataType.MATRIX || in2.getDataType() == Types.DataType.MATRIX) && out.getDataType() != Types.DataType.MATRIX) {
            throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() + " and " + in2.getName() + " must produce a matrix, which " + out.getName() + " is not");
        }
    }

    /**
     * Rewrites a Spark instruction string into the corresponding CP instruction string.
     * <p>
     * The rewrite works field by field (fields are separated by the operand delimiter) so that
     * operand names and literal values are never modified by accident:
     * <ul>
     *   <li>an exec-type field (index 0) equal to {@code SPARK} becomes {@code CP};</li>
     *   <li>the opcode field (index 1) loses a leading {@code map} prefix (e.g. {@code map+} becomes {@code +});</li>
     *   <li>operand fields exactly equal to {@code RIGHT}, {@code ROW_VECTOR} or
     *       {@code COL_VECTOR} are dropped, since they only carry Spark broadcast hints.</li>
     * </ul>
     * A plain {@code String.replace} would instead hit every occurrence anywhere, e.g. stripping
     * the {@code map} prefix off a variable named {@code mapWeights}.
     *
     * @param inst_str the Spark instruction string
     * @return the rewritten CP instruction string
     */
    protected static String rewriteSparkInstructionToCP(String inst_str) {
        // limit -1 keeps trailing empty fields so the string is reassembled faithfully
        String[] fields = inst_str.split(OPERAND_DELIM, -1);
        StringBuilder sb = new StringBuilder(inst_str.length());
        for (int i = 0; i < fields.length; i++) {
            String field = fields[i];
            if (i == 0) {
                if (field.equals(Types.ExecType.SPARK.name())) {
                    field = Types.ExecType.CP.name();
                }
            } else if (i == 1) {
                if (field.startsWith(MAP_OPCODE_PREFIX)) {
                    field = field.substring(MAP_OPCODE_PREFIX.length());
                }
            } else if (BinaryFEDInstruction.isSparkOnlyField(field)) {
                continue; // drop the field together with its delimiter
            }
            if (i > 0) {
                sb.append(OPERAND_DELIM);
            }
            sb.append(field);
        }
        return sb.toString();
    }

    /** Returns true for operand fields that only carry Spark broadcast/vector hints. */
    private static boolean isSparkOnlyField(String field) {
        return field.equals("RIGHT")
            || field.equals(BinaryM.VectorType.ROW_VECTOR.name())
            || field.equals(BinaryM.VectorType.COL_VECTOR.name());
    }
}

