/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.Serializable;
import java.util.LinkedList;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.utils.stats.ParamServStatistics;
import scala.Tuple2;

/**
 * Spark-side helpers for the parameter server: assembling row-aligned training data from
 * blocked feature/label RDDs and partitioning that data across workers.
 */
public class SparkParamservUtils {

    /**
     * Joins a blocked feature matrix and a blocked label matrix on their row-block index.
     *
     * <p>Each input is first collapsed so that all column blocks of a row block are
     * column-bound, in ascending column-block order, into a single {@link MatrixBlock}.
     *
     * @param featuresRDD blocked feature matrix, keyed by block indexes
     * @param labelsRDD   blocked label matrix, keyed by block indexes
     * @return pairs {@code rowBlockIndex -> (featuresRowBlock, labelsRowBlock)}; only row-block
     *         indexes present in both inputs survive (inner join)
     */
    public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> assembleTrainingData(
            JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD,
            JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) {
        JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(featuresRDD);
        JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(labelsRDD);
        return fRDD.join(lRDD);
    }

    /**
     * Collapses all column blocks that share a row-block index into one matrix block.
     *
     * @param rdd blocked matrix keyed by {@code (rowBlockIndex, colBlockIndex)}
     * @return {@code rowBlockIndex -> cbind(colBlock_1, ..., colBlock_k)} with the column blocks
     *         ordered by ascending column-block index
     */
    private static JavaPairRDD<Long, MatrixBlock> groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
        return rdd
            // (rowIdx, colIdx) -> block  ==>  rowIdx -> (colIdx, block)
            .mapToPair((PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, Tuple2<Long, MatrixBlock>>) input ->
                new Tuple2<>(input._1.getRowIndex(), new Tuple2<>(input._1.getColumnIndex(), input._2)))
            // Gather every (colIdx, block) pair of a row block; no ordering is guaranteed here.
            .aggregateByKey(new LinkedList<Tuple2<Long, MatrixBlock>>(),
                (list, colBlock) -> {
                    list.add(colBlock);
                    return list;
                },
                (l1, l2) -> {
                    l1.addAll(l2);
                    return l1;
                })
            .mapToPair((PairFunction<Tuple2<Long, LinkedList<Tuple2<Long, MatrixBlock>>>, Long, MatrixBlock>) input -> {
                LinkedList<Tuple2<Long, MatrixBlock>> colBlocks = input._2;
                // Sort here instead of in the combiner: Spark skips the combiner when all blocks
                // of a key come from a single map-side partition, which would leave them unsorted
                // and cbind the columns in the wrong order.
                colBlocks.sort((a, b) -> Long.compare(a._1, b._1));
                // aggregateByKey only emits keys with at least one value, so the list is non-empty.
                MatrixBlock result = colBlocks.getFirst()._2;
                for (Tuple2<Long, MatrixBlock> colBlock : colBlocks.subList(1, colBlocks.size())) {
                    result = ParamservUtils.cbindMatrix(result, colBlock._2);
                }
                return new Tuple2<>(input._1, result);
            });
    }

    /**
     * Partitions the training data (features + labels) across {@code workerNum} workers on Spark.
     *
     * <p>Pipeline: join features and labels per row block, then apply the
     * {@link DataPartitionerSparkMapper} for {@code scheme}. Next, gather each worker's entries
     * into one list ordered by their {@code Long} row key. Finally, build each worker's
     * matrices with {@link DataPartitionerSparkAggregator}.
     *
     * @param sec       Spark execution context used to obtain the input RDDs
     * @param features  feature matrix
     * @param labels    label matrix
     * @param scheme    data partitioning scheme handed to the mapper
     * @param workerNum number of workers; also the number of shuffle partitions, with worker id
     *                  {@code i} routed to partition {@code i}
     * @return {@code workerID -> (features, labels)} as produced by the aggregator
     */
    public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, final int workerNum) {
        // NOTE(review): RDD transformations are lazy, so this timer probably measures mostly
        // plan construction rather than the actual partitioning work — confirm if relevant.
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;

        @SuppressWarnings("unchecked")
        JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
            sec.getRDDHandleForMatrixObject(features, Types.FileFormat.BINARY);
        @SuppressWarnings("unchecked")
        JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
            sec.getRDDHandleForMatrixObject(labels, Types.FileFormat.BINARY);

        DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
        JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = assembleTrainingData(featuresRDD, labelsRDD)
            // presumably workerID -> (rowId, (featuresRow, labelsRow)); see DataPartitionerSparkMapper
            .flatMapToPair(mapper)
            .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
                private static final long serialVersionUID = -7937781374718031224L;

                @Override
                public int getPartition(Object workerID) {
                    // One partition per worker: the worker id is the partition index.
                    return (Integer) workerID;
                }

                @Override
                public int numPartitions() {
                    return workerNum;
                }
            }, (list, row) -> {
                list.add(row);
                return list;
            }, (l1, l2) -> {
                l1.addAll(l2);
                return l1;
            })
            // Order each worker's rows by row id. Doing this in the combiner alone is not enough:
            // Spark skips the combiner when a key's values come from one map-side partition.
            // mapValues keeps the worker partitioner and needs no extra shuffle.
            .mapValues(rows -> {
                rows.sort((a, b) -> Long.compare(a._1, b._1));
                return rows;
            })
            .mapToPair(new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns()));

        if (DMLScript.STATISTICS) {
            ParamServStatistics.accSetupTime((long) tSetup.stop());
        }
        return result;
    }
}

