/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.fedplanner;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.hops.fedplanner.AFederatedPlanner;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.fedplanner.FederatedCompilationTimer;
import org.apache.sysds.hops.fedplanner.FederatedPlannerUtils;
import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.utils.Explain;

/**
 * Cost-based federated planner.
 *
 * <p>Planning runs in two phases:
 * <ol>
 *   <li><b>Enumeration</b> ({@link #enumeratePlans}): walks every statement block of the program,
 *       recursing into while/if/for/function bodies and into called functions. For each HOP it
 *       records the alternative {@link HopRel}s (at most one FOUT per output federation type,
 *       optionally a LOUT, or a NONE fallback) in a {@link MemoTable}.</li>
 *   <li><b>Selection</b> ({@link #selectPlan}): starts at each collected terminal HOP, picks the
 *       minimum-cost alternative from the memo, and propagates that choice down the input
 *       dependencies. The chosen federated output, cost and exec type are written onto the HOPs.</li>
 * </ol>
 */
public class FederatedPlannerCostbased
extends AFederatedPlanner {
    private static final Log LOG = LogFactory.getLog(FederatedPlannerCostbased.class.getName());
    /** Alternative plans (HopRels) per HOP, filled during enumeration. */
    private final MemoTable hopRelMemo = new MemoTable();
    /** IDs of HOPs whose final federated output has already been set during selection. */
    private final Set<Long> hopRelUpdatedFinal = new HashSet<Long>();
    /** Root HOPs for which {@link HopRewriteUtils#isTerminalHop} held; selection starts here. */
    private final List<Hop> terminalHops = new ArrayList<Hop>();
    /** Most recently visited transient-write HOP per variable name (plus mapped function outputs). */
    private final Map<String, Hop> transientWrites = new HashMap<String, Hop>();
    /** Runtime variables consulted when a transient read cannot be resolved from the HOP DAG. */
    private LocalVariableMap localVariableMap = new LocalVariableMap();

    /**
     * Returns the terminal HOPs collected during enumeration.
     *
     * @return the live internal list (not a copy)
     */
    public List<Hop> getTerminalHops() {
        return this.terminalHops;
    }

    /** Enumerates plans for the whole program, selects the cheapest ones and updates explain output. */
    @Override
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        this.enumeratePlans(prog);
        this.selectPlan();
        this.updateExplain();
        FederatedCompilationTimer.activate();
    }

    /** Updates repetition estimates and enumerates HopRel alternatives for all statement blocks. */
    private void enumeratePlans(DMLProgram prog) {
        FederatedCompilationTimer.startEnumerationTimer();
        prog.updateRepetitionEstimates();
        this.rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
        FederatedCompilationTimer.stopEnumerationTimer();
    }

    /** Fixes the federated output of every HOP reachable from the terminal HOPs. */
    private void selectPlan() {
        FederatedCompilationTimer.startSelectPlanTimer();
        this.setFinalFedouts();
        FederatedCompilationTimer.stopSelectPlanTimer();
    }

    /**
     * Plans a single function at runtime using the concrete argument values.
     *
     * <p>Note: {@code funcArgs} remains installed as this planner's local variable map after return.
     *
     * @param function the function statement block to plan
     * @param funcArgs runtime arguments used to resolve transient reads
     */
    @Override
    public void rewriteFunctionDynamic(FunctionStatementBlock function, LocalVariableMap funcArgs) {
        this.localVariableMap = funcArgs;
        this.rewriteStatementBlock(function.getDMLProg(), function, null);
        this.setFinalFedouts();
        this.updateExplain();
    }

    /** Rewrites each statement block in order and concatenates the results. */
    private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram prog, List<StatementBlock> sbs, Map<String, Hop> paramMap) {
        ArrayList<StatementBlock> rewrittenStmBlocks = new ArrayList<StatementBlock>();
        for (StatementBlock stmBlock : sbs) {
            rewrittenStmBlocks.addAll(this.rewriteStatementBlock(prog, stmBlock, paramMap));
        }
        return rewrittenStmBlocks;
    }

    /**
     * Dispatches on the concrete statement block type and enumerates plans for its HOPs.
     *
     * @param prog     the program (used to look up called functions)
     * @param sb       the statement block to process
     * @param paramMap parameter-name to HOP mapping of the enclosing function call, or null
     * @return a singleton list containing {@code sb}
     */
    public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog, StatementBlock sb, Map<String, Hop> paramMap) {
        if (sb instanceof WhileStatementBlock) {
            return this.rewriteWhileStatementBlock(prog, (WhileStatementBlock)sb, paramMap);
        }
        if (sb instanceof IfStatementBlock) {
            return this.rewriteIfStatementBlock(prog, (IfStatementBlock)sb, paramMap);
        }
        if (sb instanceof ForStatementBlock) {
            return this.rewriteForStatementBlock(prog, (ForStatementBlock)sb, paramMap);
        }
        if (sb instanceof FunctionStatementBlock) {
            return this.rewriteFunctionStatementBlock(prog, (FunctionStatementBlock)sb, paramMap);
        }
        return this.rewriteDefaultStatementBlock(prog, sb, paramMap);
    }

    /** Plans the loop predicate, then every statement block of the loop body. */
    private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram prog, WhileStatementBlock whileSB, Map<String, Hop> paramMap) {
        this.selectFederatedExecutionPlan(whileSB.getPredicateHops(), paramMap);
        for (Statement stm : whileSB.getStatements()) {
            WhileStatement whileStm = (WhileStatement)stm;
            whileStm.setBody(this.rewriteStatementBlocks(prog, whileStm.getBody(), paramMap));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(whileSB));
    }

    /** Plans the if predicate, then both the if-body and the else-body. */
    private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram prog, IfStatementBlock ifSB, Map<String, Hop> paramMap) {
        this.selectFederatedExecutionPlan(ifSB.getPredicateHops(), paramMap);
        for (Statement statement : ifSB.getStatements()) {
            IfStatement ifStatement = (IfStatement)statement;
            ifStatement.setIfBody(this.rewriteStatementBlocks(prog, ifStatement.getIfBody(), paramMap));
            ifStatement.setElseBody(this.rewriteStatementBlocks(prog, ifStatement.getElseBody(), paramMap));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(ifSB));
    }

    /**
     * Plans the from/to/increment HOPs, then the loop body.
     *
     * <p>The body is planned against a cloned variable map in which the iteration variable is bound
     * to a placeholder {@code IntObject(-1)}. The previous map is restored afterwards, even if
     * planning the body throws, so the loop-scoped binding never leaks.
     */
    private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram prog, ForStatementBlock forSB, Map<String, Hop> paramMap) {
        this.selectFederatedExecutionPlan(forSB.getFromHops(), paramMap);
        this.selectFederatedExecutionPlan(forSB.getToHops(), paramMap);
        this.selectFederatedExecutionPlan(forSB.getIncrementHops(), paramMap);
        DataIdentifier iterVar = ((ForStatement)forSB.getStatement(0)).getIterablePredicate().getIterVar();
        LocalVariableMap outerVariableMap = this.localVariableMap;
        this.localVariableMap = (LocalVariableMap)outerVariableMap.clone();
        try {
            this.localVariableMap.put(iterVar.getName(), new IntObject(-1L));
            for (Statement statement : forSB.getStatements()) {
                ForStatement forStatement = (ForStatement)statement;
                forStatement.setBody(this.rewriteStatementBlocks(prog, forStatement.getBody(), paramMap));
            }
        } finally {
            this.localVariableMap = outerVariableMap;
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(forSB));
    }

    /** Plans every statement block in the function body using the given parameter map. */
    private ArrayList<StatementBlock> rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB, Map<String, Hop> paramMap) {
        for (Statement statement : funcSB.getStatements()) {
            FunctionStatement funcStm = (FunctionStatement)statement;
            funcStm.setBody(this.rewriteStatementBlocks(prog, funcStm.getBody(), paramMap));
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(funcSB));
    }

    /**
     * Plans each HOP of a basic block. For function calls, the called function body is also planned.
     *
     * <p>Every HOP of this block is planned with the caller's {@code paramMap}. The callee's
     * parameter map is built per call and is never reused for later HOPs of this block.
     */
    private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb, Map<String, Hop> paramMap) {
        if (sb.hasHops()) {
            for (Hop sbHop : sb.getHops()) {
                this.selectFederatedExecutionPlan(sbHop, paramMap);
                if (sbHop instanceof FunctionOp) {
                    this.rewriteCalledFunction(prog, (FunctionOp)sbHop, paramMap);
                }
            }
        }
        return new ArrayList<StatementBlock>(Collections.singletonList(sb));
    }

    /**
     * Plans the body of the function called by {@code funcOp} and maps its outputs back to the caller.
     *
     * <p>The callee's parameter map (from {@link FederatedPlannerUtils#getParamMap}) is extended with
     * the caller's entries when both exist. If {@code getParamMap} returns null, null is passed on.
     * NOTE(review): because {@code putAll} runs last, caller entries override callee entries with
     * the same name. Confirm whether callee parameters should shadow them instead.
     */
    private void rewriteCalledFunction(DMLProgram prog, FunctionOp funcOp, Map<String, Hop> callerParamMap) {
        Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap(funcOp);
        if (callerParamMap != null && funcParamMap != null) {
            funcParamMap.putAll(callerParamMap);
        }
        FunctionStatementBlock funcBlock = prog.getFunctionDictionary(funcOp.getFunctionNamespace())
            .getFunction(funcOp.getFunctionName());
        this.rewriteStatementBlock(prog, funcBlock, funcParamMap);
        FunctionStatement funcStatement = (FunctionStatement)funcBlock.getStatement(0);
        FederatedPlannerUtils.mapFunctionOutputs(funcOp, funcStatement, this.transientWrites);
    }

    /** Runs plan selection starting from every collected terminal HOP. */
    public void setFinalFedouts() {
        for (Hop root : this.terminalHops) {
            this.setFinalFedout(root);
        }
    }

    /** Selects the minimum-cost alternative of {@code root} and propagates it to its inputs. */
    private void setFinalFedout(Hop root) {
        HopRel optimalRootHopRel = this.hopRelMemo.getMinCostAlternative(root);
        this.setFinalFedout(root, optimalRootHopRel);
    }

    /**
     * Applies {@code rootHopRel} to {@code root}.
     *
     * <p>If {@code root} was already finalized and the requested local-output property differs from
     * the one already applied, a federated-output alternative (if any) is applied and prefetch is
     * activated. Otherwise the requested alternative is applied. Inputs are only visited on the
     * first finalization of a HOP.
     */
    private void setFinalFedout(Hop root, HopRel rootHopRel) {
        if (this.hopRelUpdatedFinal.contains(root.getHopID())) {
            boolean outputLocalityConflicts = rootHopRel.hasLocalOutput() ^ root.hasLocalOutput();
            if (outputLocalityConflicts && this.hopRelMemo.hasFederatedOutputAlternative(root)) {
                this.updateFederatedOutput(root, this.hopRelMemo.getFederatedOutputAlternative(root));
                root.activatePrefetch();
            } else {
                this.updateFederatedOutput(root, rootHopRel);
            }
        } else {
            this.updateFederatedOutput(root, rootHopRel);
            this.visitInputDependency(rootHopRel);
        }
    }

    /** Recursively finalizes every input HopRel that {@code rootHopRel} depends on. */
    private void visitInputDependency(HopRel rootHopRel) {
        List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
        for (HopRel input : hopRelInputs) {
            this.setFinalFedout(input.getHopRef(), input);
        }
    }

    /** Copies output, cost and exec type from {@code updateHopRel} onto {@code root} and marks it final. */
    private void updateFederatedOutput(Hop root, HopRel updateHopRel) {
        root.setFederatedOutput(updateHopRel.getFederatedOutput());
        root.setFederatedCost(updateHopRel.getCostObject());
        root.setForcedExecType(updateHopRel.getExecType());
        this.forceFixedFedOut(root);
        LOG.trace("Updated fedOut to " + updateHopRel.getFederatedOutput() + " for hop " + root.getHopID() + " opcode: " + root.getOpString());
        this.hopRelUpdatedFinal.add(root.getHopID());
    }

    /**
     * Overrides the federated output with a user specification keyed by the HOP's begin line, if present.
     * Prefetch is deactivated when the specification forces federated output.
     */
    private void forceFixedFedOut(Hop root) {
        if (OptimizerUtils.FEDERATED_SPECS.containsKey(root.getBeginLine())) {
            FEDInstruction.FederatedOutput fedOutSpec = OptimizerUtils.FEDERATED_SPECS.get(root.getBeginLine());
            root.setFederatedOutput(fedOutSpec);
            if (fedOutSpec.isForcedFederated()) {
                root.deactivatePrefetch();
            }
        }
    }

    /** Enumerates plans for each of the given root HOPs. */
    private void selectFederatedExecutionPlan(ArrayList<Hop> roots, Map<String, Hop> paramMap) {
        for (Hop root : roots) {
            this.selectFederatedExecutionPlan(root, paramMap);
        }
    }

    /** Enumerates plans for the DAG below {@code root} (null is ignored) and records it if terminal. */
    private void selectFederatedExecutionPlan(Hop root, Map<String, Hop> paramMap) {
        if (root != null) {
            this.visitFedPlanHop(root, paramMap);
            if (HopRewriteUtils.isTerminalHop(root)) {
                this.terminalHops.add(root);
            }
        }
    }

    /**
     * Post-order visit: plans all inputs first, then stores the alternatives of {@code currentHop}
     * in the memo. A NONE alternative is stored when no other alternative exists. HOPs already in
     * the memo are skipped.
     */
    private void visitFedPlanHop(Hop currentHop, Map<String, Hop> paramMap) {
        if (this.hopRelMemo.containsHop(currentHop)) {
            return;
        }
        this.debugLog(currentHop);
        for (Hop input : currentHop.getInput()) {
            this.visitFedPlanHop(input, paramMap);
        }
        ArrayList<HopRel> hopRels = this.getFedPlans(currentHop, paramMap);
        if (hopRels.isEmpty()) {
            hopRels.add(this.getNONEHopRel(currentHop, paramMap));
        }
        this.addTrace(hopRels);
        this.hopRelMemo.put(currentHop, hopRels);
    }

    /** Returns the planning inputs: resolved transient inputs for transient reads, else the direct inputs. */
    private ArrayList<Hop> getHopInputs(Hop currentHop, Map<String, Hop> paramMap) {
        if (HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD)) {
            return FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, this.transientWrites, this.localVariableMap);
        }
        return currentHop.getInput();
    }

    /** Builds the NONE alternative, deriving its FType from the FTypes of its input dependencies. */
    private HopRel getNONEHopRel(Hop currentHop, Map<String, Hop> paramMap) {
        ArrayList<Hop> inputs = this.getHopInputs(currentHop, paramMap);
        HopRel noneHopRel = new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, this.hopRelMemo, inputs);
        FTypes.FType[] inputFType = noneHopRel.getInputDependency().stream().map(HopRel::getFType).toArray(FTypes.FType[]::new);
        FTypes.FType outputFType = this.getFederatedOut(currentHop, inputFType);
        noneHopRel.setFType(outputFType);
        return noneHopRel;
    }

    /**
     * Enumerates the FOUT and LOUT alternatives of {@code currentHop}.
     *
     * <p>Transient reads are resolved to their producing HOPs. If that resolution returns null,
     * the alternatives are derived from the runtime variable instead. Transient writes are
     * registered for later reads.
     */
    private ArrayList<HopRel> getFedPlans(Hop currentHop, Map<String, Hop> paramMap) {
        ArrayList<HopRel> hopRels = new ArrayList<HopRel>();
        ArrayList<Hop> inputHops = currentHop.getInput();
        if (HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD)) {
            inputHops = FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, this.transientWrites, this.localVariableMap);
            if (inputHops == null) {
                return this.createHopRelsFromRuntimeVars(currentHop, hopRels);
            }
        }
        if (HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE)) {
            this.transientWrites.put(currentHop.getName(), currentHop);
        }
        if (HopRewriteUtils.isData(currentHop, Types.OpOpData.FEDERATED)) {
            hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.FOUT, this.deriveFType((DataOp)currentHop), this.hopRelMemo, inputHops));
        } else {
            hopRels.addAll(this.generateHopRels(currentHop, inputHops));
        }
        if (this.isLOUTSupported(currentHop)) {
            hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.LOUT, this.hopRelMemo, inputHops));
        }
        return hopRels;
    }

    /**
     * Derives a single alternative from the runtime value of a transient read.
     * The alternative is FOUT with the federation type if the value carries a federation mapping,
     * and LOUT otherwise.
     *
     * @throws DMLRuntimeException if the variable is not present in the local variable map
     */
    private ArrayList<HopRel> createHopRelsFromRuntimeVars(Hop currentHop, ArrayList<HopRel> hopRels) {
        Data variable = this.localVariableMap.get(currentHop.getName());
        if (variable == null) {
            throw new DMLRuntimeException("Transient write not found for " + currentHop);
        }
        FederationMap fedMapping = null;
        if (variable instanceof CacheableData) {
            fedMapping = ((CacheableData)variable).getFedMapping();
        }
        if (fedMapping != null) {
            hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.FOUT, fedMapping.getType(), this.hopRelMemo, new ArrayList<Hop>()));
        } else {
            hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.LOUT, this.hopRelMemo, new ArrayList<Hop>()));
        }
        return hopRels;
    }

    /**
     * Builds FOUT alternatives for every combination of input FTypes that permits federated execution.
     * Only the cheapest alternative per output FType is kept. On equal cost, the later candidate
     * replaces the earlier one.
     */
    private Collection<HopRel> generateHopRels(Hop currentHop, List<Hop> inputHops) {
        List<List<FTypes.FType>> inputFTypeCombinations = this.getAllCombinations(this.getValidFTypes(inputHops));
        Map<FTypes.FType, HopRel> foutHopRelMap = new HashMap<FTypes.FType, HopRel>();
        for (List<FTypes.FType> inputCombination : inputFTypeCombinations) {
            FTypes.FType[] inputFTypes = inputCombination.toArray(new FTypes.FType[0]);
            if (!this.allowsFederated(currentHop, inputFTypes)) {
                LOG.trace("Does not allow federated: " + currentHop + " input FTypes: " + inputCombination);
                continue;
            }
            FTypes.FType outputFType = this.getFederatedOut(currentHop, inputFTypes);
            if (outputFType == null) {
                LOG.trace("Allows federated, but FOUT is not allowed: " + currentHop + " input FTypes: " + inputCombination);
                continue;
            }
            HopRel alt = new HopRel(currentHop, FEDInstruction.FederatedOutput.FOUT, outputFType, this.hopRelMemo, inputHops, inputCombination);
            // Keep the existing entry only if it is strictly cheaper than the new candidate.
            foutHopRelMap.merge(outputFType, alt,
                (existing, candidate) -> existing.getCost() < candidate.getCost() ? existing : candidate);
        }
        return foutHopRelMap.values();
    }

    /** Collects, per input HOP, the FTypes available among its memoized alternatives. */
    private List<List<FTypes.FType>> getValidFTypes(List<Hop> inputHops) {
        ArrayList<List<FTypes.FType>> validFTypes = new ArrayList<List<FTypes.FType>>();
        for (Hop inputHop : inputHops) {
            validFTypes.add(this.hopRelMemo.getFTypes(inputHop));
        }
        return validFTypes;
    }

    /**
     * Returns the cartesian product of the given per-input FType lists.
     * An empty outer list yields a single empty combination. Any empty inner list yields no combinations.
     */
    public List<List<FTypes.FType>> getAllCombinations(List<List<FTypes.FType>> validFTypes) {
        ArrayList<List<FTypes.FType>> resultList = new ArrayList<List<FTypes.FType>>();
        this.buildCombinations(validFTypes, resultList, 0, new ArrayList<FTypes.FType>());
        return resultList;
    }

    /**
     * Recursive helper for {@link #getAllCombinations}. Extends {@code currentResult} with each choice
     * at {@code currentIndex}, and adds completed combinations to {@code result}.
     */
    public void buildCombinations(List<List<FTypes.FType>> validFTypes, List<List<FTypes.FType>> result, int currentIndex, List<FTypes.FType> currentResult) {
        if (currentIndex == validFTypes.size()) {
            result.add(currentResult);
            return;
        }
        for (FTypes.FType currentType : validFTypes.get(currentIndex)) {
            ArrayList<FTypes.FType> currentPass = new ArrayList<FTypes.FType>(currentResult);
            currentPass.add(currentType);
            this.buildCombinations(validFTypes, result, currentIndex + 1, currentPass);
        }
    }

    /** Publishes the memo table to {@link Explain} when HOPS explain output is enabled. */
    public void updateExplain() {
        if (DMLScript.EXPLAIN == Explain.ExplainType.HOPS) {
            Explain.setMemo(this.hopRelMemo);
        }
    }

    /** Logs the visited HOP, its privacy constraints and its inputs at debug level. */
    private void debugLog(Hop currentHop) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Visiting HOP: " + currentHop + " Input size: " + currentHop.getInput().size());
            if (currentHop.getPrivacy() != null) {
                LOG.debug(currentHop.getPrivacy());
            }
            int index = 0;
            for (Hop hop : currentHop.getInput()) {
                if (hop == null) {
                    LOG.debug("Input at index is null: " + index);
                } else {
                    LOG.debug("HOP input: " + hop + " at index " + index + " of " + currentHop);
                }
                ++index;
            }
        }
    }

    /** Logs each alternative being added to the memo at trace level. */
    private void addTrace(ArrayList<HopRel> hopRels) {
        if (LOG.isTraceEnabled()) {
            for (HopRel hr : hopRels) {
                LOG.trace("Adding to memo: " + hr);
            }
        }
    }

    /** LOUT is allowed unless the HOP carries privacy constraints. */
    private boolean isLOUTSupported(Hop associatedHop) {
        return associatedHop.getPrivacy() == null || !associatedHop.getPrivacy().hasConstraints();
    }
}

