/*
 * Decompiled with CFR 0.152.
 */
package jmarkov.jmdp.solvers;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import jmarkov.basic.Action;
import jmarkov.basic.Actions;
import jmarkov.basic.DecisionRule;
import jmarkov.basic.Policy;
import jmarkov.basic.Solution;
import jmarkov.basic.State;
import jmarkov.basic.StatesSet;
import jmarkov.basic.ValueFunction;
import jmarkov.basic.exceptions.SolverException;
import jmarkov.jmdp.CTMDP;
import jmarkov.jmdp.DTMDP;
import jmarkov.jmdp.InfiniteMDP;
import jmarkov.jmdp.solvers.AbstractDiscountedSolver;

public class ValueIterationSolver<S extends State, A extends Action>
extends AbstractDiscountedSolver<S, A> {
    private double epsilon = 1.0E-4;
    private double initVal = 0.0;
    private boolean useGaussSeidel = true;
    private boolean useErrorBounds = false;
    private boolean isAverage = false;
    private List<Double> difValues = new ArrayList<Double>();
    protected long processTime = 0L;
    protected long iterations;
    A bestAction;
    private double fixedRelativeCost = 0.0;
    private ValueFunction<S> relativeValueFunction = new ValueFunction();
    private double gain = 0.0;
    private boolean useModifiedAverage = false;
    boolean printBias = false;
    boolean printGain = false;

    ValueIterationSolver(DTMDP<S, A> problem, double interestRate, double initValue, double epsilon, boolean useGaussSeidel, boolean useErrorBounds, boolean isAverage, boolean useModifiedAverage) {
        super(problem, interestRate);
        this.initVal = initValue;
        this.epsilon = epsilon;
        this.useGaussSeidel = useGaussSeidel;
        this.useErrorBounds = useErrorBounds;
        this.isAverage = isAverage;
        this.useModifiedAverage = useModifiedAverage;
        problem.setSolver(this);
    }

    ValueIterationSolver(CTMDP<S, A> problem, double interestRate, double initValue, double epsilon, boolean useGaussSeidel, boolean useErrorBounds, boolean isAverage, boolean useModifiedAverage) {
        super(problem, interestRate);
        this.initVal = initValue;
        this.epsilon = epsilon;
        this.useGaussSeidel = useGaussSeidel;
        this.useErrorBounds = useErrorBounds;
        this.isAverage = isAverage;
        this.useModifiedAverage = useModifiedAverage;
        problem.setSolver(this);
    }

    public ValueIterationSolver(DTMDP<S, A> problem, double interestRate) {
        this(problem, interestRate, 0.0, 1.0E-4, true, false, false, false);
    }

    public ValueIterationSolver(CTMDP<S, A> problem, double interestRate) {
        this(problem, interestRate, 0.0, 0.001, true, false, false, false);
    }

    ValueIterationSolver(DTMDP<S, A> problem, boolean useModifiedAverage) {
        this(problem, 0.0, 0.0, 0.001, true, false, true, useModifiedAverage);
    }

    ValueIterationSolver(CTMDP<S, A> problem, boolean useModifiedAverage) {
        this(problem, 0.0, 0.0, 0.001, true, false, true, useModifiedAverage);
    }

    public synchronized void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    private void setInitVal(double val) {
        this.initVal = val;
    }

    public synchronized void useGaussSeidel(boolean val) {
        this.useGaussSeidel = val;
    }

    public final double getEpsilon() {
        return this.epsilon;
    }

    public final boolean isAverage() {
        return this.isAverage;
    }

    final boolean usesModifiedAverage() {
        return this.useModifiedAverage;
    }

    public final boolean usesErrorBounds() {
        return this.useErrorBounds;
    }

    public final boolean usesGaussSeidel() {
        return this.useGaussSeidel;
    }

    public synchronized void useErrorBounds(boolean val) {
        this.useErrorBounds = val;
    }

    private synchronized Solution<S, A> solve(int maxIterations) throws Exception {
        this.init();
        double actualDifference = Double.MAX_VALUE;
        long initialTime = 0L;
        this.iterations = 0L;
        initialTime = System.currentTimeMillis();
        double bound = (1.0 - this.discountFactor) * this.epsilon / (2.0 * this.discountFactor);
        if (this.isAverage && !this.useModifiedAverage) {
            bound = this.epsilon;
        }
        while (actualDifference > bound && this.iterations < (long)maxIterations) {
            actualDifference = this.useErrorBounds ? this.computeWithErrorBounds() : this.computeNoErrorBounds();
            this.difValues.add(new Double(actualDifference));
            this.getProblem().debug(3, "Max difference from previous value = " + actualDifference);
            ++this.iterations;
        }
        this.processTime = System.currentTimeMillis() - initialTime;
        if (this.isAverage) {
            this.relativeValueFunction = this.valueFunction;
            this.gain = this.fixedRelativeCost;
            this.valueFunction = this.buildValueFunction(this.valueFunction, this.fixedRelativeCost);
        }
        this.solved = true;
        return new Solution(this.valueFunction, this.policy);
    }

    @Override
    public Solution<S, A> solve() {
        try {
            return this.solve(Integer.MAX_VALUE);
        }
        catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    protected void init() {
        StatesSet st = ((InfiniteMDP)this.getProblem()).getAllStates();
        for (State s : st) {
            this.valueFunction.set(s, this.initVal);
        }
    }

    protected double computeNoErrorBounds() {
        StatesSet st = ((InfiniteMDP)this.getProblem()).getAllStates();
        DecisionRule<State, A> decisionRuleCompute = this.iterations > 0L ? new DecisionRule(this.policy.getDecisionRule()) : new DecisionRule<State, A>();
        ValueFunction vf = new ValueFunction(this.valueFunction);
        Iterator it = vf.iterator();
        Iterator itGlobal = this.valueFunction.iterator();
        Iterator itDR = decisionRuleCompute.iterator();
        double maxDifference = 0.0;
        int n = 0;
        this.fixedRelativeCost = 0.0;
        for (State i : st) {
            double diff;
            this.bestAction = null;
            double newValFunction = this.bestAction(i);
            if (n == 0 && this.isAverage) {
                this.fixedRelativeCost = newValFunction;
            }
            if (this.useGaussSeidel && !this.isAverage) {
                Map.Entry entryGlobal = itGlobal.next();
                diff = Math.abs(newValFunction - entryGlobal.getValue());
                entryGlobal.setValue(newValFunction);
            } else {
                Map.Entry entry = it.next();
                if (this.isAverage) {
                    diff = Math.abs(newValFunction - entry.getValue() - this.fixedRelativeCost);
                    if (this.useModifiedAverage) {
                        entry.setValue(newValFunction - this.fixedRelativeCost + (1.0 - this.discountFactor) * entry.getValue());
                    } else {
                        entry.setValue(newValFunction - this.fixedRelativeCost);
                    }
                } else {
                    diff = Math.abs(newValFunction - entry.getValue());
                    entry.setValue(newValFunction);
                }
            }
            if (maxDifference < diff) {
                maxDifference = diff;
            }
            if (this.iterations > 0L) {
                Map.Entry decisionRuleEntry = null;
                decisionRuleEntry = itDR.next();
                decisionRuleEntry.setValue(this.bestAction);
            } else {
                decisionRuleCompute.set(i, this.bestAction);
            }
            ++n;
        }
        if (!this.useGaussSeidel || this.isAverage) {
            this.valueFunction = vf;
        }
        this.policy = new Policy(decisionRuleCompute);
        return maxDifference;
    }

    protected double computeWithErrorBounds() {
        StatesSet st = ((InfiniteMDP)this.getProblem()).getAllStates();
        DecisionRule<State, A> decisionRuleCompute = this.iterations > 0L ? new DecisionRule(this.policy.getDecisionRule()) : new DecisionRule<State, A>();
        ValueFunction vf = new ValueFunction(this.valueFunction);
        Iterator it = vf.iterator();
        Iterator itGlobal = this.valueFunction.iterator();
        Iterator itDR = decisionRuleCompute.iterator();
        double maxDifference = 0.0;
        double minDifference = Double.MAX_VALUE;
        int n = 0;
        this.fixedRelativeCost = 0.0;
        for (State i : st) {
            this.bestAction = null;
            double newValueFunction = this.bestAction(i);
            if (this.isAverage && n == 0) {
                this.fixedRelativeCost = newValueFunction;
            }
            double diff = 0.0;
            Map.Entry decisionRuleEntry = null;
            if (this.useGaussSeidel && !this.isAverage) {
                Map.Entry entryGlobal = itGlobal.next();
                diff = Math.abs(newValueFunction - entryGlobal.getValue());
                entryGlobal.setValue(newValueFunction);
            } else {
                Map.Entry entry = it.next();
                if (this.isAverage) {
                    diff = Math.abs(newValueFunction - entry.getValue() - this.fixedRelativeCost);
                    entry.setValue(newValueFunction - this.fixedRelativeCost);
                } else {
                    diff = Math.abs(newValueFunction - entry.getValue());
                    entry.setValue(newValueFunction);
                }
            }
            if (maxDifference < diff) {
                maxDifference = diff;
            }
            if (minDifference > diff) {
                minDifference = diff;
            }
            if (this.iterations > 0L) {
                decisionRuleEntry = itDR.next();
                decisionRuleEntry.setValue(this.bestAction);
            } else {
                decisionRuleCompute.set(i, this.bestAction);
            }
            ++n;
        }
        if (!this.useGaussSeidel || this.isAverage) {
            this.valueFunction = vf;
        }
        this.policy = new Policy(decisionRuleCompute);
        if (!this.isAverage) {
            return this.discountFactor / (1.0 - this.discountFactor) * (maxDifference - minDifference);
        }
        return maxDifference - minDifference;
    }

    protected double bestAction(S i) {
        Actions act = ((InfiniteMDP)this.getProblem()).feasibleActions(i);
        double val = 0.0;
        double minSoFar = Double.MAX_VALUE;
        for (Action a : act) {
            val = this.getProblem().operation(this.future(i, a, this.discountFactor), this.getDiscreteProblem().immediateCost(i, a));
            if (!(val < minSoFar)) continue;
            minSoFar = val;
            this.bestAction = a;
        }
        return minSoFar;
    }

    private ValueFunction<S> buildValueFunction(ValueFunction<S> vf, double optimum) {
        ValueFunction<State> result = new ValueFunction<State>(vf);
        Iterator<Map.Entry<S, Double>> it = vf.iterator();
        while (it.hasNext()) {
            Map.Entry<S, Double> m = it.next();
            result.set((State)m.getKey(), optimum);
        }
        return result;
    }

    private double difference(ValueFunction<S> vf, S i, double newValFunction) {
        double diff = Math.abs(newValFunction - vf.get(i));
        if (this.useGaussSeidel) {
            this.valueFunction.set(i, newValFunction);
        } else {
            vf.set(i, newValFunction);
        }
        return diff;
    }

    @Override
    public final long getProcessTime() {
        return this.processTime;
    }

    @Override
    public final long getIterations() {
        return this.iterations;
    }

    @Override
    public String label() {
        StringBuffer buf = new StringBuffer();
        buf.append("Value Iter. Solver ");
        if (this.isAverage) {
            buf.append(" (Avg)");
        } else {
            buf.append(" (Disc)");
        }
        return buf.toString();
    }

    @Override
    public String description() {
        StringBuffer buf = new StringBuffer();
        if (this.isAverage) {
            buf.append("Value Iteration Solver for Average Cost problem, ");
            if (this.useModifiedAverage) {
                buf.append("Factor = " + this.discountFactor);
            }
        } else {
            buf.append("Value Iteration Solver\n");
            buf.append("Discount Factor = " + this.discountFactor);
        }
        if (this.useGaussSeidel) {
            buf.append(",\nusing Gauss-Seidel modification\n");
        }
        if (this.useErrorBounds) {
            buf.append(",\nusing Error Bounds convergence\n");
        }
        return buf.toString();
    }

    public final double getGain() {
        return this.gain;
    }

    public final ValueFunction<S> getBias() {
        return this.relativeValueFunction;
    }

    public void setPrintBias(boolean val) {
        this.printBias = val;
    }

    public void setPrintGain(boolean val) {
        this.printGain = val;
    }

    @Override
    public void printSolution(PrintWriter pw) {
        pw.println(this);
        try {
            this.getOptimalPolicy().print(pw);
            if (this.printValueFunction) {
                this.valueFunction.print(pw);
            }
            if (this.printBias) {
                pw.println("Bias dor each state:");
            }
            this.relativeValueFunction.print(pw);
            if (this.printGain) {
                pw.println("Gain = " + this.gain);
            }
            if (this.printProcessTime) {
                pw.println("Process time = " + this.getProcessTime() + " milliseconds");
            }
        }
        catch (SolverException e) {
            pw.print(" Error solving the problem :" + e);
        }
    }

    @Override
    public void printSolution() throws Exception {
        this.printSolution(new PrintWriter(System.out));
    }
}

