/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IndexedIdentifier;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that, in Spark execution mode, prepends a checkpointing
 * statement block to {@code while} and {@code for} loops.
 *
 * <p>For every matrix or tensor variable that a loop reads but never updates, the injected
 * block performs a transient read of that variable flagged with
 * {@code setRequiresCheckpoint(true)} and transiently writes it back under the same name.
 * The injected block is marked as a split DAG. Presumably this lets the loop body consume a
 * checkpointed (persisted) Spark representation instead of recomputing it on every
 * iteration; NOTE(review): confirm against the runtime checkpoint handling.
 */
public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRule {
    /** If true, loops encountered while the rewrite status is in a parfor context are left untouched. */
    private final boolean _checkCtx;

    /**
     * @param checkParForContext if true, skip loops for which
     *        {@link ProgramRewriteStatus#isInParforContext()} reports a parfor context
     */
    public RewriteInjectSparkLoopCheckpointing(boolean checkParForContext) {
        this._checkCtx = checkParForContext;
    }

    /** This rewrite injects blocks marked via {@code setSplitDag(true)}. */
    @Override
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Injects a checkpointing block in front of {@code sb} if it is an eligible loop.
     *
     * @param sb the statement block to rewrite
     * @param status rewrite status; supplies the block size and parfor context, and is
     *        updated via {@code setInjectedCheckpoints()} when a block is injected
     * @return {@code [sb]} if nothing was injected (always the case outside Spark mode),
     *         otherwise {@code [checkpointBlock, sb]}
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) {
        if (!OptimizerUtils.isSparkExecutionMode()) {
            // checkpoint injection only applies to Spark execution
            return Arrays.asList(sb);
        }
        ArrayList<StatementBlock> ret = new ArrayList<>();
        int blocksize = status.getBlocksize();
        if (isEligibleLoop(sb, status)) {
            VariableSet read = sb.variablesRead();
            VariableSet updated = sb.variablesUpdated();
            List<String> candidates = collectCandidates(read, updated);
            if (!candidates.isEmpty()) {
                ret.add(createCheckpointBlock(sb, read, candidates, blocksize));
                status.setInjectedCheckpoints();
            }
        }
        // the original block always comes last, after any injected checkpoint block
        ret.add(sb);
        return ret;
    }

    /** Returns true if {@code sb} is a while/for loop and not excluded by the parfor-context check. */
    private boolean isEligibleLoop(StatementBlock sb, ProgramRewriteStatus status) {
        boolean isLoop = sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock;
        // short-circuit keeps isInParforContext() from being queried for non-loops
        return isLoop && !(this._checkCtx && status.isInParforContext());
    }

    /** Collects matrix/tensor variables that are read by the loop but never updated inside it. */
    private static List<String> collectCandidates(VariableSet read, VariableSet updated) {
        List<String> candidates = new ArrayList<>();
        for (String rvar : read.getVariableNames()) {
            if (updated.containsVariable(rvar)) {
                continue; // loop-variant variables cannot be checkpointed once up front
            }
            Types.DataType dt = read.getVariable(rvar).getDataType();
            if (dt == Types.DataType.MATRIX || dt == Types.DataType.TENSOR) {
                candidates.add(rvar);
            }
        }
        return candidates;
    }

    /**
     * Builds the split-DAG statement block containing, for each candidate, a checkpoint-flagged
     * transient read written back to the same variable via a transient write.
     */
    private static StatementBlock createCheckpointBlock(StatementBlock sb, VariableSet read,
            List<String> candidates, int blocksize) {
        StatementBlock sb0 = new StatementBlock();
        sb0.setDMLProg(sb.getDMLProg());
        sb0.setParseInfo(sb);
        ArrayList<Hop> hops = new ArrayList<>();
        VariableSet livein = new VariableSet();
        VariableSet liveout = new VariableSet();
        for (String var : candidates) {
            DataIdentifier dat = read.getVariable(var);
            // indexed identifiers carry the dimensions of the whole (unindexed) variable separately
            long dim1 = dat instanceof IndexedIdentifier ? ((IndexedIdentifier) dat).getOrigDim1() : dat.getDim1();
            long dim2 = dat instanceof IndexedIdentifier ? ((IndexedIdentifier) dat).getOrigDim2() : dat.getDim2();
            // Keep the variable's own data type: candidates include TENSOR variables, which
            // previously were re-read as MATRIX. NOTE(review): value type remains FP64 as before.
            DataOp tread = new DataOp(var, dat.getDataType(), Types.ValueType.FP64, Types.OpOpData.TRANSIENTREAD,
                dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize);
            tread.setRequiresCheckpoint(true);
            hops.add(HopRewriteUtils.createTransientWrite(var, tread));
            livein.addVariable(var, dat);
            liveout.addVariable(var, dat);
        }
        sb0.setHops(hops);
        sb0.setLiveIn(livein);
        sb0.setLiveOut(liveout);
        sb0.setSplitDag(true);
        return sb0;
    }

    /** Multi-block rewrite is a no-op for this rule; returns {@code sbs} unchanged. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus status) {
        return sbs;
    }
}

