/*
 * Decompiled with CFR 0.152.
 */
package org.joone.engine.extenders;

import org.joone.engine.ExtendableLearner;
import org.joone.engine.LearnableLayer;
import org.joone.engine.LearnableSynapse;
import org.joone.engine.RpropParameters;
import org.joone.engine.extenders.DeltaRuleExtender;
import org.joone.log.ILogger;
import org.joone.log.LoggerFactory;

/**
 * Learner extender that computes bias and weight updates with the RPROP
 * (resilient propagation) rule.
 *
 * <p>Each trainable parameter has its own step size in {@link #theDeltas}.
 * Only the sign of the accumulated gradient enters the update, never its
 * magnitude. The step size behaves as follows:
 * <ul>
 *   <li>When the accumulated gradient keeps its sign, the step size is scaled
 *       by {@code etaInc}, capped at {@code maxDelta}.</li>
 *   <li>When the sign flips, the step size is scaled by {@code etaDec},
 *       floored at {@code minDelta}, and the returned update is the negation
 *       of the value held in the bias/weight matrix's {@code delta} array.
 *       That value is presumably the last applied update, i.e. weight
 *       backtracking; confirm against the matrix class.</li>
 * </ul>
 *
 * <p>Per-parameter state is kept in three parallel arrays. They are shaped
 * {@code [rows][1]} when the learner is attached to a layer (biases), or
 * {@code [inputDimension][outputDimension]} when attached to a synapse
 * (weights).
 */
public class RpropExtender
extends DeltaRuleExtender {
    private static final ILogger log = LoggerFactory.getLogger(RpropExtender.class);
    /** Current per-parameter step sizes. */
    protected double[][] theDeltas;
    /** Gradient recorded at the last applied step; reset to 0 after a sign change. */
    protected double[][] thePreviousGradients;
    /** RPROP parameters; created lazily by {@link #getParameters()}. */
    protected RpropParameters theRpropParameters;
    /** Gradients accumulated since the last applied step. */
    protected double[][] theSummedGradients;

    /**
     * (Re)allocates the per-parameter state and seeds every step size.
     *
     * <p>The arrays are sized from the learner's layer when one is attached,
     * otherwise from its synapse. Each step size is then set to
     * {@code getParameters().getInitialDelta(row, col)}. A warning is logged
     * when the monitor's learning rate is not exactly 1. If neither a layer
     * nor a synapse is attached, and no state was allocated before, nothing
     * is initialized.
     */
    public void reinit() {
        ExtendableLearner learner = this.getLearner();
        if (learner.getMonitor().getLearningRate() != 1.0) {
            log.warn("RPROP learning rate should be equal to 1.");
        }
        LearnableLayer layer = learner.getLayer();
        if (layer != null) {
            // Biases: one column per row of the layer.
            this.allocateState(layer.getRows(), 1);
        } else {
            LearnableSynapse synapse = learner.getSynapse();
            if (synapse != null) {
                this.allocateState(synapse.getInputDimension(), synapse.getOutputDimension());
            }
        }
        if (this.theDeltas == null) {
            // No layer or synapse attached and no earlier state: nothing to seed.
            // The original code threw a NullPointerException here.
            return;
        }
        RpropParameters parameters = this.getParameters();
        for (int row = 0; row < this.theDeltas.length; ++row) {
            for (int col = 0; col < this.theDeltas[row].length; ++col) {
                this.theDeltas[row][col] = parameters.getInitialDelta(row, col);
            }
        }
    }

    /** Allocates zero-filled gradient and step-size arrays of shape {@code [rows][cols]}. */
    private void allocateState(int rows, int cols) {
        this.thePreviousGradients = new double[rows][cols];
        this.theSummedGradients = new double[rows][cols];
        this.theDeltas = new double[rows][cols];
    }

    /**
     * Accumulates the gradient for bias {@code n} and, when the update
     * extender is storing weights/biases, returns the RPROP bias update.
     *
     * @param dArray not used by this implementation
     * @param n      bias (row) index
     * @param d      incoming delta; its negation is added to the accumulated gradient
     * @return the RPROP update for bias {@code n}, or 0.0 when no step is taken
     */
    public double getDelta(double[] dArray, int n, double d) {
        return this.accumulateAndStep(n, 0, d, true);
    }

    /**
     * Accumulates the gradient for weight {@code [n][n2]} and, when the update
     * extender is storing weights/biases, returns the RPROP weight update.
     *
     * @param dArray  not used by this implementation
     * @param n       weight row index
     * @param dArray2 not used by this implementation
     * @param n2      weight column index
     * @param d       incoming delta; its negation is added to the accumulated gradient
     * @return the RPROP update for weight {@code [n][n2]}, or 0.0 when no step is taken
     */
    public double getDelta(double[] dArray, int n, double[] dArray2, int n2, double d) {
        return this.accumulateAndStep(n, n2, d, false);
    }

    /**
     * Applies the shared RPROP logic for both overloads of {@code getDelta}.
     *
     * <p>The negated incoming delta is always added to the accumulated
     * gradient. A step is computed, and the accumulator cleared, only when
     * {@code storeWeightsBiases()} is true; this presumably marks the end of a
     * batch, so confirm against the update-weight extender. Otherwise 0.0 is
     * returned.
     */
    private double accumulateAndStep(int row, int col, double incomingDelta, boolean isBias) {
        this.theSummedGradients[row][col] -= incomingDelta;
        ExtendableLearner learner = this.getLearner();
        if (!learner.getUpdateWeightExtender().storeWeightsBiases()) {
            return 0.0;
        }
        RpropParameters parameters = this.getParameters();
        double summed = this.theSummedGradients[row][col];
        // Sign of previous * current gradient decides grow / shrink / plain step.
        double signProduct = this.thePreviousGradients[row][col] * summed;
        double step;
        if (signProduct > 0.0) {
            this.theDeltas[row][col] = Math.min(this.theDeltas[row][col] * parameters.getEtaInc(), parameters.getMaxDelta());
            step = -1.0 * this.sign(summed) * this.theDeltas[row][col];
            this.thePreviousGradients[row][col] = summed;
        } else if (signProduct < 0.0) {
            this.theDeltas[row][col] = Math.max(this.theDeltas[row][col] * parameters.getEtaDec(), parameters.getMinDelta());
            step = -1.0 * this.previousUpdate(learner, row, col, isBias);
            // Zeroing the stored gradient makes the next call take the plain-step branch.
            this.thePreviousGradients[row][col] = 0.0;
        } else {
            step = -1.0 * this.sign(summed) * this.theDeltas[row][col];
            this.thePreviousGradients[row][col] = summed;
        }
        this.theSummedGradients[row][col] = 0.0;
        return step;
    }

    /** Reads entry {@code [row][col]} of the bias or weight matrix's {@code delta} array. */
    private double previousUpdate(ExtendableLearner learner, int row, int col, boolean isBias) {
        if (isBias) {
            return learner.getLayer().getBias().delta[row][col];
        }
        return learner.getSynapse().getWeights().delta[row][col];
    }

    /** No post-update work is needed; intentionally empty. */
    public void postBiasUpdate(double[] dArray) {
    }

    /** No post-update work is needed; intentionally empty. */
    public void postWeightUpdate(double[] dArray, double[] dArray2) {
    }

    /**
     * Calls {@link #reinit()} when the state is missing or its row count does
     * not match the layer's row count.
     */
    public void preBiasUpdate(double[] dArray) {
        if (this.theDeltas == null || this.theDeltas.length != this.getLearner().getLayer().getRows()) {
            this.reinit();
        }
    }

    /**
     * Calls {@link #reinit()} when the state is missing or its shape does not
     * match the synapse's input/output dimensions.
     */
    public void preWeightUpdate(double[] dArray, double[] dArray2) {
        LearnableSynapse synapse = this.getLearner().getSynapse();
        boolean missing = this.theDeltas == null;
        boolean rowMismatch = !missing && this.theDeltas.length != synapse.getInputDimension();
        // Guard theDeltas[0]: a zero-row array has no first row to inspect.
        boolean colMismatch = !missing && !rowMismatch && this.theDeltas.length > 0
                && this.theDeltas[0].length != synapse.getOutputDimension();
        if (missing || rowMismatch || colMismatch) {
            this.reinit();
        }
    }

    /**
     * Returns the RPROP parameters. A default {@code RpropParameters} instance
     * is created on first use if none has been set.
     */
    public RpropParameters getParameters() {
        if (this.theRpropParameters == null) {
            this.theRpropParameters = new RpropParameters();
        }
        return this.theRpropParameters;
    }

    /** Replaces the RPROP parameters used by subsequent calls. */
    public void setParameters(RpropParameters rpropParameters) {
        this.theRpropParameters = rpropParameters;
    }

    /**
     * Returns 1.0 for positive, -1.0 for negative, and 0.0 otherwise
     * (including zero, negative zero and NaN).
     */
    protected double sign(double d) {
        if (d > 0.0) {
            return 1.0;
        }
        if (d < 0.0) {
            return -1.0;
        }
        return 0.0;
    }
}

