/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup.scheme;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * A compression scheme holding one {@link ICLAScheme} per column group. It can encode uncompressed
 * {@link MatrixBlock}s into a {@link CompressedMatrixBlock} with one column group per scheme. It can also
 * update the individual schemes with new data.
 *
 * <p>Not synchronized: the {@code update*} methods replace entries of the internal scheme array in place.
 * Do not update a single instance from several threads concurrently.
 */
public class CompressionScheme {
    protected static final Log LOG = LogFactory.getLog(CompressionScheme.class.getName());

    /**
     * In-memory size (as reported by {@link MatrixBlock#getInMemorySize()}) below which
     * {@link #updateAndEncode(MatrixBlock, int)} uses the single-threaded path instead of a thread pool.
     */
    private static final long PARALLEL_UPDATE_AND_ENCODE_THRESHOLD = 160000L;

    /** The per-column-group schemes. The update methods replace entries in place. */
    private final ICLAScheme[] encodings;

    /**
     * Create a compression scheme from the given per-column-group schemes.
     *
     * @param encodings the schemes to use. The array is stored as-is (not copied), and the update methods
     *                  modify it in place.
     */
    public CompressionScheme(ICLAScheme[] encodings) {
        this.encodings = encodings;
    }

    /**
     * Get the scheme at the given index.
     *
     * @param i index of the scheme
     * @return the scheme currently stored at index {@code i}
     */
    public ICLAScheme get(int i) {
        return encodings[i];
    }

    /**
     * Encode the given block with the current schemes. The stored schemes are not replaced.
     *
     * @param mb the uncompressed input block
     * @return a compressed block with one column group per scheme. It has the dimensions and non-zero
     *         count of {@code mb}.
     * @throws NotImplementedException if {@code mb} is already a {@link CompressedMatrixBlock}
     */
    public CompressedMatrixBlock encode(MatrixBlock mb) {
        validateInput(mb);
        final ArrayList<AColGroup> ret = new ArrayList<>(encodings.length);
        for (ICLAScheme enc : encodings) {
            ret.add(enc.encode(mb));
        }
        return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret);
    }

    /**
     * Encode the given block with the current schemes, running one task per scheme on a thread pool.
     *
     * @param mb the uncompressed input block
     * @param k  degree of parallelism. A value of 1 delegates to {@link #encode(MatrixBlock)}.
     * @return a compressed block with one column group per scheme
     * @throws NotImplementedException if {@code mb} is already a {@link CompressedMatrixBlock}
     * @throws DMLCompressionException if any encoding task fails or the thread is interrupted
     */
    public CompressedMatrixBlock encode(MatrixBlock mb, int k) {
        if (k == 1) {
            return encode(mb);
        }
        validateInput(mb);
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<EncodeTask> tasks = new ArrayList<>(encodings.length);
            for (ICLAScheme enc : encodings) {
                tasks.add(new EncodeTask(enc, mb));
            }
            final ArrayList<AColGroup> ret = new ArrayList<>(encodings.length);
            // invokeAll returns the futures in task order, so the groups keep the scheme order.
            for (Future<AColGroup> f : pool.invokeAll(tasks)) {
                ret.add(f.get());
            }
            return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret);
        }
        catch (Exception e) {
            throw parallelFailure(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Update every scheme with the given block. Each stored scheme is replaced by the result of
     * {@link ICLAScheme#update(MatrixBlock)}.
     *
     * @param mb the uncompressed input block
     * @return this instance
     * @throws NotImplementedException if {@code mb} is already a {@link CompressedMatrixBlock}
     */
    public CompressionScheme update(MatrixBlock mb) {
        validateInput(mb);
        for (int i = 0; i < encodings.length; ++i) {
            encodings[i] = encodings[i].update(mb);
        }
        return this;
    }

    /**
     * Update every scheme with the given block, running one task per scheme on a thread pool.
     *
     * @param mb the uncompressed input block
     * @param k  degree of parallelism. A value of 1 delegates to {@link #update(MatrixBlock)}.
     * @return this instance
     * @throws NotImplementedException if {@code mb} is already a {@link CompressedMatrixBlock}
     * @throws DMLCompressionException if any update task fails or the thread is interrupted
     */
    public CompressionScheme update(MatrixBlock mb, int k) {
        if (k == 1) {
            return update(mb);
        }
        validateInput(mb);
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<UpdateTask> tasks = new ArrayList<>(encodings.length);
            for (ICLAScheme enc : encodings) {
                tasks.add(new UpdateTask(enc, mb));
            }
            final List<Future<ICLAScheme>> ret = pool.invokeAll(tasks);
            // Write results back only after all tasks have completed, in scheme order.
            for (int i = 0; i < encodings.length; ++i) {
                encodings[i] = ret.get(i).get();
            }
            return this;
        }
        catch (Exception e) {
            throw parallelFailure(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Extract a compression scheme from an existing compressed block. The scheme holds one entry per
     * column group, taken from {@link AColGroup#getCompressionScheme()}.
     *
     * @param cmb the compressed block
     * @return a new compression scheme mirroring the column groups of {@code cmb}
     * @throws DMLCompressionException if {@code cmb} is overlapping
     */
    public static CompressionScheme getScheme(CompressedMatrixBlock cmb) {
        if (cmb.isOverlapping()) {
            throw new DMLCompressionException("Invalid to extract CompressionScheme from an overlapping compression");
        }
        final List<AColGroup> gs = cmb.getColGroups();
        final ICLAScheme[] ret = new ICLAScheme[gs.size()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = gs.get(i).getCompressionScheme();
        }
        return new CompressionScheme(ret);
    }

    /**
     * Update every scheme with the given block and encode it in the same pass, using a thread pool.
     *
     * <p>Falls back to {@link #updateAndEncode(MatrixBlock)} when {@code k == 1} or the block is smaller than
     * {@link #PARALLEL_UPDATE_AND_ENCODE_THRESHOLD}. The input may be transposed first, depending on
     * {@link CompressedMatrixBlockFactory#transposeHeuristics}. The schemes are split into contiguous ranges,
     * roughly {@code 4 * k} tasks, and each task updates and encodes its range.
     *
     * @param mb the uncompressed input block
     * @param k  degree of parallelism
     * @return a compressed block with the original (untransposed) dimensions of {@code mb}
     * @throws NotImplementedException if {@code mb} is already a {@link CompressedMatrixBlock}
     * @throws DMLCompressionException if any task fails or the thread is interrupted
     */
    public CompressedMatrixBlock updateAndEncode(MatrixBlock mb, int k) {
        if (k == 1 || mb.getInMemorySize() < PARALLEL_UPDATE_AND_ENCODE_THRESHOLD) {
            return updateAndEncode(mb);
        }
        validateInput(mb);
        // Capture the output dimensions before a possible transpose swaps them.
        final int nRow = mb.getNumRows();
        final int nCol = mb.getNumColumns();
        boolean transposed = false;
        if (CompressedMatrixBlockFactory.transposeHeuristics(encodings.length, mb)) {
            transposed = true;
            mb = LibMatrixReorg.transpose(mb, k, true);
        }
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final AColGroup[] ret = new AColGroup[encodings.length];
            final List<UpdateAndEncodeTask> tasks = new ArrayList<>();
            final int taskSize = Math.max(1, encodings.length / (4 * k));
            for (int i = 0; i < encodings.length; i += taskSize) {
                tasks.add(new UpdateAndEncodeTask(i, Math.min(encodings.length, i + taskSize), ret, mb, transposed));
            }
            // Tasks write into ret and encodings directly. get() surfaces any task failure.
            for (Future<Object> f : pool.invokeAll(tasks)) {
                f.get();
            }
            final ArrayList<AColGroup> retA = new ArrayList<>(Arrays.asList(ret));
            return new CompressedMatrixBlock(nRow, nCol, mb.getNonZeros(), false, retA);
        }
        catch (Exception e) {
            throw parallelFailure(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Update every scheme with the given block and encode it in the same pass, single-threaded.
     *
     * <p>If the block's sparsity is below 0.1, it is transposed first, and the {@code updateAndEncodeT}
     * variants of the schemes are used.
     *
     * @param mb the uncompressed input block
     * @return a compressed block with the original (untransposed) dimensions of {@code mb}
     * @throws NotImplementedException if {@code mb} is already a {@link CompressedMatrixBlock}
     */
    public CompressedMatrixBlock updateAndEncode(MatrixBlock mb) {
        validateInput(mb);
        // Capture the output dimensions before a possible transpose swaps them. Previously the result used
        // the transposed block's dimensions, which reported rows and columns swapped.
        final int nRow = mb.getNumRows();
        final int nCol = mb.getNumColumns();
        final ArrayList<AColGroup> ret = new ArrayList<>(encodings.length);
        boolean transposed = false;
        if (mb.getSparsity() < 0.1) {
            transposed = true;
            mb = LibMatrixReorg.transpose(mb, 1, true);
        }
        for (int i = 0; i < encodings.length; ++i) {
            final ICLAScheme e = encodings[i];
            final Pair<ICLAScheme, AColGroup> p = transposed ? e.updateAndEncodeT(mb) : e.updateAndEncode(mb);
            encodings[i] = p.getKey();
            ret.add(p.getValue());
        }
        return new CompressedMatrixBlock(nRow, nCol, mb.getNonZeros(), false, ret);
    }

    /**
     * Reject inputs that are already compressed.
     *
     * @param mb the input block
     * @throws NotImplementedException if {@code mb} is a {@link CompressedMatrixBlock}
     */
    private void validateInput(MatrixBlock mb) {
        if (mb instanceof CompressedMatrixBlock) {
            throw new NotImplementedException("Not implemented schema encode/apply on an already compressed MatrixBlock");
        }
    }

    /**
     * Wrap a failure from a parallel section in a {@link DMLCompressionException}. If the failure is an
     * interruption, the current thread's interrupt flag is restored first so callers can still observe it.
     *
     * @param e the caught exception
     * @return the exception to throw
     */
    private static DMLCompressionException parallelFailure(Exception e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
        return new DMLCompressionException("Failed encoding", e);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getClass().getSimpleName());
        sb.append("\n");
        sb.append(Arrays.toString(encodings));
        return sb.toString();
    }

    /**
     * Task that updates and encodes the schemes in the index range [{@code i}, {@code e}). It writes the new
     * schemes back into the enclosing instance's scheme array and the column groups into {@code ret}.
     * Distinct tasks cover disjoint ranges.
     */
    protected class UpdateAndEncodeTask
    implements Callable<Object> {
        /** First scheme index (inclusive). */
        final int i;
        /** Last scheme index (exclusive). */
        final int e;
        final MatrixBlock mb;
        /** Shared output array, indexed like the scheme array. */
        final AColGroup[] ret;
        /** Whether {@code mb} has already been transposed, selecting {@code updateAndEncodeT}. */
        final boolean transposed;

        protected UpdateAndEncodeTask(int i, int e, AColGroup[] ret, MatrixBlock mb, boolean transposed) {
            this.i = i;
            this.e = e;
            this.mb = mb;
            this.ret = ret;
            this.transposed = transposed;
        }

        @Override
        public Object call() throws Exception {
            for (int j = i; j < e; ++j) {
                final ICLAScheme sc = CompressionScheme.this.encodings[j];
                final Pair<ICLAScheme, AColGroup> p = transposed ? sc.updateAndEncodeT(mb) : sc.updateAndEncode(mb);
                CompressionScheme.this.encodings[j] = p.getKey();
                ret[j] = p.getValue();
            }
            return null;
        }
    }

    /** Task returning the result of {@link ICLAScheme#update(MatrixBlock)} for a single scheme. */
    protected class UpdateTask
    implements Callable<ICLAScheme> {
        final ICLAScheme enc;
        final MatrixBlock mb;

        protected UpdateTask(ICLAScheme enc, MatrixBlock mb) {
            this.enc = enc;
            this.mb = mb;
        }

        @Override
        public ICLAScheme call() throws Exception {
            return enc.update(mb);
        }
    }

    /** Task returning the result of {@link ICLAScheme#encode(MatrixBlock)} for a single scheme. */
    protected class EncodeTask
    implements Callable<AColGroup> {
        final ICLAScheme enc;
        final MatrixBlock mb;

        protected EncodeTask(ICLAScheme enc, MatrixBlock mb) {
            this.enc = enc;
            this.mb = mb;
        }

        @Override
        public AColGroup call() throws Exception {
            return enc.encode(mb);
        }
    }
}

