/*
 * Decompiled with CFR 0.152.
 */
package jmarkov.jmdp.solvers;

import java.util.Iterator;
import java.util.Map;
import jmarkov.basic.Action;
import jmarkov.basic.DecisionRule;
import jmarkov.basic.State;
import jmarkov.basic.ValueFunction;
import jmarkov.basic.exceptions.SolverException;
import jmarkov.jmdp.CT2DTConverter;
import jmarkov.jmdp.CTMDP;
import jmarkov.jmdp.DTMDP;

/**
 * Computes the steady-state probabilities of the Markov chain obtained by fixing a decision
 * rule in a Markov decision process, i.e. a vector {@code pi} satisfying {@code pi = pi P}.
 *
 * <p>Two iterative schemes are available, selected with {@link #setJacobi(boolean)}:
 * <ul>
 *   <li>a Jacobi-style update (default), {@code pi_i <- (sum_{k != i} pi_k p_ki) / (1 - p_ii)},
 *       optionally written in place when {@link #setGaussSeidel(boolean)} is enabled;</li>
 *   <li>the power method, {@code pi <- pi P}.</li>
 * </ul>
 * Both stop when the largest relative change between consecutive iterates is at most
 * {@code epsilon}, or after {@code maxIterations} sweeps. The final vector is normalized to sum
 * to one whenever its total is finite and positive.
 *
 * @param <S> state type
 * @param <A> action type
 */
public class ProbabilitySolver<S extends State, A extends Action> {
    /** Decision rule fixing the action used in every state. */
    DecisionRule<S, A> dr;
    /** Result of the last {@link #solve()}; {@code null} before solving. */
    private ValueFunction<S> probability = null;
    /** Discrete-time problem that supplies reachable states and transition probabilities. */
    DTMDP<S, A> problem;
    /** Convergence tolerance on the largest relative change between consecutive iterates. */
    double epsilon = 1.0E-6;
    /** Number of states in the decision rule; recomputed by {@link #initializeVF()}. */
    int size = 0;
    /** When true (and {@link #jacobi} is true), Jacobi updates are written in place. */
    boolean gaussSeidel = false;
    /** When true use the Jacobi-style scheme, otherwise the power method. */
    boolean jacobi = true;
    /** Wall-clock milliseconds spent in the last solve. */
    long processTime = 0L;
    /** Number of sweeps performed by the last solve. */
    int iterations;
    /** Set once {@link #solve()} has run (regardless of whether it converged). */
    boolean solved = false;
    /** Upper bound on the number of sweeps performed by either scheme. */
    int maxIterations = 100;

    /**
     * Creates a solver for the chain induced by {@code dr} on a discrete-time problem.
     *
     * @param problem discrete-time MDP providing the transitions
     * @param dr decision rule to evaluate
     */
    public ProbabilitySolver(DTMDP<S, A> problem, DecisionRule<S, A> dr) {
        this.dr = dr;
        this.problem = problem;
    }

    /**
     * Creates a solver for the chain induced by the problem's optimal policy.
     *
     * @param problem discrete-time MDP providing the transitions
     * @throws SolverException if obtaining the optimal policy fails
     */
    public ProbabilitySolver(DTMDP<S, A> problem) throws SolverException {
        this.problem = problem;
        this.dr = problem.getOptimalPolicy().getDecisionRule();
    }

    /**
     * Creates a solver for a continuous-time problem, which is wrapped by
     * {@link CT2DTConverter} to obtain discrete-time transitions.
     *
     * @param problem continuous-time MDP
     * @param dr decision rule to evaluate
     */
    public ProbabilitySolver(CTMDP<S, A> problem, DecisionRule<S, A> dr) {
        this.dr = dr;
        this.problem = new CT2DTConverter<S, A>(problem);
    }

    /**
     * Creates a solver for a continuous-time problem using its optimal policy; the problem is
     * wrapped by {@link CT2DTConverter} to obtain discrete-time transitions.
     *
     * @param problem continuous-time MDP
     * @throws SolverException if obtaining the optimal policy fails
     */
    public ProbabilitySolver(CTMDP<S, A> problem) throws SolverException {
        this.problem = new CT2DTConverter<S, A>(problem);
        this.dr = problem.getOptimalPolicy().getDecisionRule();
    }

    /** @return whether {@link #solve()} has been called */
    public boolean isSolved() {
        return this.solved;
    }

    /** Runs the configured scheme and stores the result for {@link #getProbability()}. */
    public void solve() {
        if (this.jacobi) {
            this.solveJacobi();
        } else {
            this.solvePower();
        }
        this.solved = true;
    }

    /**
     * Jacobi-style iteration on the balance equations.
     *
     * <p>NOTE(review): in Gauss-Seidel mode values are overwritten in place, but the inflow term
     * for every state is computed from the iterate at the start of the sweep, so the numeric
     * sequence is the same as in plain Jacobi mode.
     */
    private ValueFunction<S> solveJacobi() {
        ValueFunction<S> current = this.initializeProbs();
        double maxDifference = Double.MAX_VALUE;
        this.iterations = 0;
        long initialTime = System.currentTimeMillis();
        while (maxDifference > this.epsilon && this.iterations < this.maxIterations) {
            ValueFunction<S> flow = this.piTimesP(current);
            // Plain Jacobi writes into a fresh copy; Gauss-Seidel mode overwrites `current`.
            ValueFunction<S> next = this.gaussSeidel ? null : new ValueFunction<S>(current);
            Iterator<Map.Entry<S, Double>> curIt = current.iterator();
            Iterator<Map.Entry<S, Double>> nextIt = next == null ? null : next.iterator();
            Iterator<Map.Entry<S, Double>> flowIt = flow.iterator();
            Iterator<Map.Entry<S, A>> drIt = this.dr.iterator();
            maxDifference = 0.0;
            // All value functions are built from dr, so their iteration orders line up.
            while (curIt.hasNext()) {
                Map.Entry<S, Double> curE = curIt.next();
                double inflow = flowIt.next().getValue();
                A action = drIt.next().getValue();
                double oldVal = curE.getValue();
                double newVal = this.jacobiUpdate(curE.getKey(), action, oldVal, inflow);
                double diff = relativeChange(oldVal, newVal);
                if (diff > maxDifference) {
                    maxDifference = diff;
                }
                if (nextIt == null) {
                    curE.setValue(newVal);
                } else {
                    nextIt.next().setValue(newVal);
                }
            }
            if (next != null) {
                current = next;
            }
            ++this.iterations;
        }
        // `current` always holds the latest iterate in both modes (previously the
        // Gauss-Seidel result was a stale copy from the start of the last sweep).
        return this.finish(current, maxDifference, initialTime);
    }

    /**
     * Computes the Jacobi update for one state.
     *
     * @param i state being updated
     * @param action action the decision rule takes in {@code i}
     * @param oldVal current value of {@code pi_i}
     * @param inflow {@code (pi P)_i} computed from the current iterate
     * @return the new value of {@code pi_i}
     */
    private double jacobiUpdate(S i, A action, double oldVal, double inflow) {
        double pii = this.problem.prob(i, i, action);
        if (pii >= 1.0) {
            // Absorbing: the balance equation for i reduces to "0 = inflow from other states"
            // and puts no constraint on pi_i, so keep it instead of dividing by zero.
            if (this.iterations == 0) {
                System.out.println("State " + i + " is an absorbing state under action " + action);
            }
            return oldVal;
        }
        // Remove the self-loop contribution to get the inflow from the other states.
        double fromOthers = inflow - oldVal * pii;
        if (Math.abs(fromOthers) < 1.0E-10) {
            return 0.0;
        }
        return fromOthers / (1.0 - pii);
    }

    /** Power method: repeatedly computes {@code pi <- pi P}, bounded by {@code maxIterations}. */
    private ValueFunction<S> solvePower() {
        ValueFunction<S> current = this.initializeProbs();
        double maxDifference = Double.MAX_VALUE;
        this.iterations = 0;
        long initialTime = System.currentTimeMillis();
        // Bounded so that iterates that never settle (e.g. a periodic chain) cannot spin forever.
        while (maxDifference > this.epsilon && this.iterations < this.maxIterations) {
            ValueFunction<S> next = this.piTimesP(current);
            maxDifference = this.difference(current, next);
            current = next;
            ++this.iterations;
        }
        return this.finish(current, maxDifference, initialTime);
    }

    /** Normalizes and stores the result, records timing and emits debug messages. */
    private ValueFunction<S> finish(ValueFunction<S> result, double maxDifference, long initialTime) {
        this.normalize(result);
        this.processTime = System.currentTimeMillis() - initialTime;
        this.probability = result;
        if (maxDifference > this.epsilon) {
            this.problem.debug(1, "Probability did not converge within " + this.maxIterations + " iterations");
        }
        this.problem.debug(1, "Probability convergence in " + this.iterations + " iterations");
        this.problem.debug(1, "Probability convergence in " + this.processTime + " milliseconds");
        return result;
    }

    /**
     * Computes the row vector {@code readPi * P} for the chain induced by the decision rule.
     *
     * @param readPi current probability vector
     * @return a new value function holding {@code readPi * P}
     * @throws IllegalStateException if a reachable state is not covered by the decision rule
     */
    private ValueFunction<S> piTimesP(ValueFunction<S> readPi) {
        ValueFunction<S> writePi = this.initializeVF();
        Iterator<Map.Entry<S, Double>> readIt = readPi.iterator();
        while (readIt.hasNext()) {
            Map.Entry<S, Double> readE = readIt.next();
            S i = readE.getKey();
            double mass = readE.getValue();
            A action = this.dr.getAction(i);
            // Merge-style scan: cheap when reachable() yields states in writePi's order.
            Iterator<Map.Entry<S, Double>> cursor = writePi.iterator();
            for (S j : this.problem.reachable(i, action)) {
                Map.Entry<S, Double> target = this.seek(cursor, j);
                if (target == null) {
                    // j lies before the cursor (orders differ): rescan from the start.
                    cursor = writePi.iterator();
                    target = this.seek(cursor, j);
                    if (target == null) {
                        throw new IllegalStateException("State " + j + " reachable from " + i
                                + " is not covered by the decision rule");
                    }
                }
                target.setValue(target.getValue() + mass * this.problem.prob(i, j, action));
            }
        }
        return writePi;
    }

    /** Advances {@code cursor} to the entry whose key equals {@code key}; null if exhausted. */
    private Map.Entry<S, Double> seek(Iterator<Map.Entry<S, Double>> cursor, S key) {
        while (cursor.hasNext()) {
            Map.Entry<S, Double> entry = cursor.next();
            if (entry.getKey().equals(key)) {
                return entry;
            }
        }
        return null;
    }

    /** @return the largest relative change between corresponding entries of the two vectors */
    private double difference(ValueFunction<S> oldPi, ValueFunction<S> newPi) {
        Iterator<Map.Entry<S, Double>> oldIt = oldPi.iterator();
        Iterator<Map.Entry<S, Double>> newIt = newPi.iterator();
        double maxDifference = 0.0;
        while (oldIt.hasNext()) {
            double diff = relativeChange(oldIt.next().getValue(), newIt.next().getValue());
            // `>` (not Math.max) so that NaN differences are ignored rather than propagated.
            if (diff > maxDifference) {
                maxDifference = diff;
            }
        }
        return maxDifference;
    }

    /**
     * Relative change {@code |newVal - oldVal| / |oldVal|}. An unchanged zero counts as no
     * change, and a zero that became non-zero counts as an infinite change.
     */
    private static double relativeChange(double oldVal, double newVal) {
        double delta = Math.abs(newVal - oldVal);
        if (oldVal == 0.0) {
            return delta == 0.0 ? 0.0 : Double.POSITIVE_INFINITY;
        }
        return delta / Math.abs(oldVal);
    }

    /** Scales {@code pi} to sum to one; left untouched if the total is zero, negative or not finite. */
    private void normalize(ValueFunction<S> pi) {
        double total = 0.0;
        Iterator<Map.Entry<S, Double>> it = pi.iterator();
        while (it.hasNext()) {
            total += it.next().getValue();
        }
        if (!(total > 0.0) || Double.isInfinite(total)) {
            return;
        }
        it = pi.iterator();
        while (it.hasNext()) {
            Map.Entry<S, Double> entry = it.next();
            entry.setValue(entry.getValue() / total);
        }
    }

    /** @param val whether Jacobi updates should be written in place */
    public void setGaussSeidel(boolean val) {
        this.gaussSeidel = val;
    }

    /** @param val true for the Jacobi-style scheme, false for the power method */
    public void setJacobi(boolean val) {
        this.jacobi = val;
    }

    /**
     * @param epsilon convergence tolerance on the largest relative change; must be positive
     * @throws IllegalArgumentException if {@code epsilon} is not positive
     */
    public void setEpsilon(double epsilon) {
        if (!(epsilon > 0.0)) {
            throw new IllegalArgumentException("epsilon must be positive: " + epsilon);
        }
        this.epsilon = epsilon;
    }

    /**
     * @param maxIterations maximum number of sweeps for either scheme; must be non-negative
     * @throws IllegalArgumentException if {@code maxIterations} is negative
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 0) {
            throw new IllegalArgumentException("maxIterations must be non-negative: " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /** @return number of sweeps performed by the last solve */
    public int getIterations() {
        return this.iterations;
    }

    /** @return wall-clock milliseconds spent in the last solve */
    public long getProcessTime() {
        return this.processTime;
    }

    /**
     * Creates a value function with one zero entry per state of the decision rule and updates
     * {@link #size} to the number of states.
     */
    ValueFunction<S> initializeVF() {
        Iterator<Map.Entry<S, A>> it = this.dr.iterator();
        ValueFunction<S> probability = new ValueFunction<S>("Steady state probabilities");
        this.size = 0;
        while (it.hasNext()) {
            Map.Entry<S, A> entry = it.next();
            probability.set(entry.getKey(), 0.0);
            ++this.size;
        }
        return probability;
    }

    /** Creates the uniform starting vector {@code 1 / size} over the decision rule's states. */
    ValueFunction<S> initializeProbs() {
        ValueFunction<S> probability = this.initializeVF();
        Iterator<Map.Entry<S, Double>> it = probability.iterator();
        while (it.hasNext()) {
            Map.Entry<S, Double> entry = it.next();
            entry.setValue(1.0 / (double) this.size);
        }
        return probability;
    }

    /** @return the probabilities computed by the last {@link #solve()}, or null before solving */
    public ValueFunction<S> getProbability() {
        return this.probability;
    }
}

