/*
 * Decompiled with CFR 0.152.
 */
package jmarkov.jmdp.solvers;

import java.util.Iterator;
import java.util.Map;
import jmarkov.basic.Action;
import jmarkov.basic.Actions;
import jmarkov.basic.DecisionRule;
import jmarkov.basic.Policy;
import jmarkov.basic.Solution;
import jmarkov.basic.State;
import jmarkov.basic.States;
import jmarkov.basic.StatesSet;
import jmarkov.basic.ValueFunction;
import jmarkov.basic.exceptions.NonStochasticException;
import jmarkov.basic.exceptions.SolverException;
import jmarkov.jmdp.DTMDP;
import jmarkov.jmdp.InfiniteMDP;
import jmarkov.jmdp.solvers.AbstractDiscountedSolver;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.sparse.BiCGstab;
import no.uib.cipr.matrix.sparse.FlexCompRowMatrix;
import no.uib.cipr.matrix.sparse.IterativeSolverNotConvergedException;
import no.uib.cipr.matrix.sparse.SparseVector;

/**
 * Policy iteration solver for discounted-cost, discrete-time MDPs.
 *
 * <p>Starting from a myopic decision rule, the solver alternates between
 * evaluating the current rule and improving it greedily until the rule no
 * longer changes. Evaluation is done either exactly, by solving the linear
 * system {@code (I - discountFactor * P) v = c} with BiCGstab, or
 * approximately ("modified" policy iteration) with a bounded number of
 * successive-approximation sweeps.
 *
 * @param <S> state type
 * @param <A> action type
 */
public class PolicyIterationSolver<S extends State, A extends Action>
extends AbstractDiscountedSolver<S, A> {
    /** Immediate cost of the current rule's action for each state, by state index. */
    private DenseVector costs;
    /** Set by {@link #policyImprovement()} when an improvement step leaves the rule unchanged. */
    private boolean isOptimal = false;
    /** When true, policies are evaluated by iterative sweeps instead of a linear solve. */
    private boolean modifiedPolicy = false;
    /** Sweep budget of the next modified evaluation; grows after every evaluation. */
    private long initialIterations = 20L;
    /** Multiplier applied to {@link #initialIterations} after every modified evaluation. */
    private double increasingFactor = 1.1;
    /** Not read anywhere in this class. */
    private long maxIterations;
    /** Number of evaluation/improvement rounds performed by the last {@link #solve()}. */
    protected long iterations;
    /** Wall-clock duration of the last {@link #solve()} call, in milliseconds. */
    protected long processTime = 0L;
    /** Tolerance used to derive the stopping bound of the modified evaluation. */
    private double epsilon = 1.0E-4;
    /** When true, modified evaluation updates values in place during a sweep. */
    private boolean gaussSeidel = true;
    /** Not read anywhere in this class. */
    private boolean errorBounds = false;
    /** Solution vector of the linear system; also passed to BiCGstab as template and start vector. */
    private DenseVector vecValueFunction = null;
    /** Coefficient matrix {@code I - discountFactor * P} for the current decision rule. */
    private FlexCompRowMatrix matrix = null;
    /** Decision rule being evaluated / improved. */
    private DecisionRule<S, A> currentDecisionRule = null;

    /**
     * Creates a solver that evaluates policies exactly (standard policy iteration).
     *
     * @param problem the problem to solve
     * @param discountFactor the discount factor applied to future costs
     */
    public PolicyIterationSolver(DTMDP<S, A> problem, double discountFactor) {
        this(problem, discountFactor, false);
    }

    /**
     * Creates a solver.
     *
     * @param problem the problem to solve
     * @param discountFactor the discount factor applied to future costs
     * @param setModifiedPolicy true to use modified policy iteration
     */
    public PolicyIterationSolver(DTMDP<S, A> problem, double discountFactor, boolean setModifiedPolicy) {
        super(problem, discountFactor);
        this.modifiedPolicy = setModifiedPolicy;
    }

    /**
     * @return the factor by which the sweep budget grows after each modified evaluation
     */
    public double getIncreasingFactor() {
        return this.increasingFactor;
    }

    /**
     * @param increasingFactor the factor by which the sweep budget grows after
     *        each modified evaluation
     */
    public void setIncreasingFactor(double increasingFactor) {
        this.increasingFactor = increasingFactor;
    }

    /**
     * @return the current sweep budget for modified evaluation (stored as a
     *         {@code long}; returned as {@code double} for API compatibility)
     */
    public double getInitialIterations() {
        return this.initialIterations;
    }

    /**
     * @param initialIterations the sweep budget used by the next modified evaluation
     */
    public void setInitialIterations(int initialIterations) {
        this.initialIterations = initialIterations;
    }

    /**
     * Runs policy iteration until the decision rule stops changing.
     *
     * @return the resulting value function and policy
     * @throws SolverException if the linear system cannot be solved
     */
    @Override
    public Solution<S, A> solve() throws SolverException {
        long initialTime = System.currentTimeMillis();
        // Reset the convergence flag: without this a second call would skip the
        // loop below and return the myopic initial rule as if it were optimal.
        this.isOptimal = false;
        this.currentDecisionRule = this.initialDecisionRuleMyopic();
        this.policy = new Policy<S, A>(this.currentDecisionRule);
        this.vecValueFunction = new DenseVector(this.getDiscreteProblem().getNumStates());
        this.matrix = this.buildMatrix(this.currentDecisionRule);
        this.iterations = 0L;
        while (!this.isOptimal) {
            this.problem.debug(2, "Iteration " + this.iterations);
            this.getProblem().debug(3, "Current Rule = " + this.currentDecisionRule);
            this.valueFunction = this.policyEvaluation();
            this.getProblem().debug(3, "Current Value function = " + this.valueFunction);
            this.currentDecisionRule = this.policyImprovement();
            ++this.iterations;
        }
        if (this.modifiedPolicy) {
            // The modified scheme only approximates the values (and does not keep
            // the matrix up to date), so finish with one exact evaluation.
            this.matrix = this.buildMatrix(this.currentDecisionRule);
            this.valueFunction = this.solveMatrix();
        }
        this.policy = new Policy<S, A>(this.currentDecisionRule);
        this.solved = true;
        this.processTime = System.currentTimeMillis() - initialTime;
        return new Solution<S, A>(this.valueFunction, this.policy);
    }

    /**
     * Builds an initial rule that picks, for every state, the first feasible
     * action returned by the problem, and resets the value function to zero.
     * Currently not called from within this class.
     *
     * @return the initial decision rule
     */
    private DecisionRule<S, A> initialDecisionRuleFirst() {
        this.valueFunction = new ValueFunction<S>();
        DecisionRule<S, A> localDecisionRule = new DecisionRule<S, A>();
        StatesSet<S> states = ((InfiniteMDP<S, A>) this.getProblem()).getAllStates();
        for (S i : states) {
            Actions<A> availableActions = ((InfiniteMDP<S, A>) this.getProblem()).feasibleActions(i);
            localDecisionRule.set(i, availableActions.iterator().next());
            this.valueFunction.set(i, 0.0);
        }
        return localDecisionRule;
    }

    /**
     * Builds an initial rule that picks, for every state, the feasible action
     * with the lowest immediate cost, and initialises the value function with
     * that cost.
     *
     * @return the initial decision rule
     */
    private DecisionRule<S, A> initialDecisionRuleMyopic() {
        this.valueFunction = new ValueFunction<S>();
        DecisionRule<S, A> localDecisionRule = new DecisionRule<S, A>();
        StatesSet<S> states = ((InfiniteMDP<S, A>) this.getProblem()).getAllStates();
        for (S state : states) {
            // The minimum must be tracked per state; carrying it across states
            // could assign a state an action chosen for (and only feasible in)
            // a previously visited state.
            A bestAction = null;
            double bestVal = Double.MAX_VALUE;
            Actions<A> availableActions = ((InfiniteMDP<S, A>) this.getProblem()).feasibleActions(state);
            for (A action : availableActions) {
                double val = this.getDiscreteProblem().immediateCost(state, action);
                if (val < bestVal) {
                    bestAction = action;
                    bestVal = val;
                }
            }
            localDecisionRule.set(state, bestAction);
            this.valueFunction.set(state, bestVal);
        }
        return localDecisionRule;
    }

    /**
     * Evaluates the current decision rule, exactly or approximately depending
     * on {@link #modifiedPolicy}, and stores the result in {@code valueFunction}.
     *
     * @return the value function of the current rule
     * @throws SolverException if the linear system cannot be solved
     */
    private ValueFunction<S> policyEvaluation() throws SolverException {
        if (this.modifiedPolicy) {
            this.valueFunction = this.solveMatrixModified(this.currentDecisionRule);
        } else {
            this.valueFunction = this.solveMatrix();
        }
        return this.valueFunction;
    }

    /**
     * Performs one greedy improvement step: for every state, selects the action
     * minimising immediate cost combined with discounted future value, and sets
     * {@link #isOptimal} when the resulting rule equals the current one.
     *
     * <p>In exact mode the matrix row and cost entry of every state whose action
     * changed are updated in place so the next evaluation does not have to
     * rebuild the whole system.
     *
     * @return the improved decision rule
     * @throws SolverException declared for callers; not thrown directly here
     */
    private DecisionRule<S, A> policyImprovement() throws SolverException {
        StatesSet<S> sts = ((InfiniteMDP<S, A>) this.getProblem()).getAllStates();
        DecisionRule<S, A> newDecisionRule = new DecisionRule<S, A>(this.currentDecisionRule);
        // NOTE(review): both rule iterators are advanced in lock-step with the
        // state loop, which assumes decision rules iterate in the same order as
        // getAllStates() - confirm against DecisionRule/StatesSet.
        Iterator<Map.Entry<S, A>> itCurDR = this.currentDecisionRule.iterator();
        Iterator<Map.Entry<S, A>> itNewDR = newDecisionRule.iterator();
        for (S i : sts) {
            Actions<A> actions = ((InfiniteMDP<S, A>) this.getProblem()).feasibleActions(i);
            A bestAction = null;
            double bestValue = Double.MAX_VALUE;
            for (A a : actions) {
                double val = this.getProblem().operation(
                        this.getDiscreteProblem().immediateCost(i, a),
                        this.future(i, a, this.discountFactor));
                if (val < bestValue) {
                    bestValue = val;
                    bestAction = a;
                }
            }
            if (!this.modifiedPolicy) {
                A curAction = itCurDR.next().getValue();
                if (!bestAction.equals(curAction)) {
                    this.matrix.setRow(i.getIndex(), this.buildRowVector(i, bestAction));
                    this.costs.set(i.getIndex(), this.getDiscreteProblem().immediateCost(i, bestAction));
                }
            }
            itNewDR.next().setValue(bestAction);
        }
        this.isOptimal = this.currentDecisionRule.equals(newDecisionRule);
        return newDecisionRule;
    }

    /**
     * Builds the row of {@code I - discountFactor * P} for state {@code i}
     * under action {@code a}.
     *
     * @param i the source state
     * @param a the action taken in {@code i}
     * @return the sparse matrix row
     * @throws NonStochasticException if the transition probabilities out of
     *         {@code i} under {@code a} differ from 1 by more than 1e-5
     */
    private SparseVector buildRowVector(S i, A a) {
        int n = this.getDiscreteProblem().getNumStates();
        States<S> reachableStates = this.getDiscreteProblem().reachable(i, a);
        SparseVector vec = new SparseVector(n, reachableStates.size());
        double sum = 0.0;
        for (S j : reachableStates) {
            double probability = this.getDiscreteProblem().prob(i, j, a);
            sum += probability;
            assert (probability >= 0.0);
            // Look the state up in the problem's state set and use that
            // instance's index as the column.
            S stored = this.getDiscreteProblem().getAllStates().get(j);
            if (probability > 0.0) {
                vec.set(stored.getIndex(), probability);
            }
        }
        if (Math.abs(sum - 1.0) > 1.0E-5) {
            throw new NonStochasticException("Probabilities do not add up to 1 for state " + i + ", and action " + a + ", sum = " + sum);
        }
        vec.scale(-this.discountFactor);
        vec.add(i.getIndex(), 1.0);
        return vec;
    }

    /**
     * Builds the full coefficient matrix for a decision rule and, as a side
     * effect, replaces {@link #costs} with the matching cost vector.
     *
     * @param currentDecisionRule the rule whose linear system is built
     * @return the matrix {@code I - discountFactor * P}
     */
    private FlexCompRowMatrix buildMatrix(DecisionRule<S, A> currentDecisionRule) {
        StatesSet<S> stts = this.getDiscreteProblem().getAllStates();
        int n = stts.size();
        FlexCompRowMatrix result = new FlexCompRowMatrix(n, n);
        this.costs = new DenseVector(n);
        for (S i : stts) {
            A a = currentDecisionRule.getAction(i);
            result.setRow(i.getIndex(), this.buildRowVector(i, a));
            this.costs.set(i.getIndex(), this.getDiscreteProblem().immediateCost(i, a));
        }
        return result;
    }

    /**
     * Solves the current linear system with BiCGstab, passing the previous
     * solution vector as the solver's template and in/out solution vector.
     *
     * @return the value function read from the solution vector
     * @throws SolverException if the iterative solver does not converge
     */
    protected ValueFunction<S> solveMatrix() throws SolverException {
        this.getProblem().debug(4, "Matrix to solve:\n" + this.matrix);
        try {
            BiCGstab solver = new BiCGstab(this.vecValueFunction);
            solver.solve(this.matrix, this.costs, this.vecValueFunction);
        }
        catch (IterativeSolverNotConvergedException e) {
            throw new SolverException("Policy iteration Solver: error solving linear system.", e);
        }
        return this.buildValueFunction(this.vecValueFunction);
    }

    /**
     * Approximately evaluates a decision rule by successive approximation,
     * starting from the current value function. The number of sweeps is limited
     * by {@link #initialIterations}, which is then multiplied by
     * {@link #increasingFactor} for the next call.
     *
     * @param localDecisionRule the rule to evaluate
     * @return the approximate value function of the rule
     */
    protected ValueFunction<S> solveMatrixModified(DecisionRule<S, A> localDecisionRule) {
        StatesSet<S> st = ((InfiniteMDP<S, A>) this.getProblem()).getAllStates();
        ValueFunction<S> vf = new ValueFunction<S>(this.valueFunction);
        ValueFunction<S> vf2 = new ValueFunction<S>(this.valueFunction);
        // Loop-invariant stopping threshold.
        double bound = (1.0 - this.discountFactor) * this.epsilon / (2.0 * this.discountFactor);
        int localIterations = 0;
        boolean toContinue = true;
        while (toContinue) {
            Iterator<Map.Entry<S, Double>> it = vf.iterator();
            Iterator<Map.Entry<S, Double>> it2 = vf2.iterator();
            double maxDifference = 0.0;
            for (S i : st) {
                A a = localDecisionRule.getAction(i);
                double val = this.getProblem().operation(
                        this.getDiscreteProblem().immediateCost(i, a),
                        this.future(i, a, this.discountFactor, vf));
                Map.Entry<S, Double> entry = it.next();
                Map.Entry<S, Double> entry2 = it2.next();
                double diff = Math.abs(val - entry.getValue());
                if (maxDifference < diff) {
                    maxDifference = diff;
                }
                if (this.gaussSeidel) {
                    // In-place update: later states in this sweep see the new value.
                    entry.setValue(val);
                } else {
                    // Jacobi update: write to the scratch buffer only.
                    entry2.setValue(val);
                }
            }
            ++localIterations;
            if (!this.gaussSeidel) {
                // Swap the buffers rather than aliasing them; aliasing would make
                // every sweep after the first update in place.
                ValueFunction<S> previous = vf;
                vf = vf2;
                vf2 = previous;
            }
            toContinue = (long) localIterations < this.initialIterations;
            if (this.isOptimal) {
                toContinue = maxDifference > bound;
            }
        }
        this.initialIterations = (long) Math.ceil(this.increasingFactor * (double) this.initialIterations);
        return vf;
    }

    /**
     * Converts a solution vector into a value function, reading each state's
     * value at the state's index (the same indexing used to build the matrix).
     *
     * @param vec the solution vector
     * @return the corresponding value function
     */
    private ValueFunction<S> buildValueFunction(DenseVector vec) {
        ValueFunction<S> vf = new ValueFunction<S>();
        StatesSet<S> stts = this.getDiscreteProblem().getAllStates();
        for (S s : stts) {
            vf.set(s, vec.get(s.getIndex()));
        }
        return vf;
    }

    /**
     * @param val true to use modified policy iteration, false for exact evaluation
     */
    public void setModifiedPolicy(boolean val) {
        this.modifiedPolicy = val;
    }

    /**
     * @return a description of the solver, prefixed with "Modified " when
     *         modified policy iteration is active
     */
    @Override
    public String description() {
        // Parenthesised: '+' binds tighter than '?:', so without the parentheses
        // the modified variant would return only "Modified ".
        return (this.modifiedPolicy ? "Modified " : "")
                + "Policy Iteration Solver\nDiscount Factor = " + this.discountFactor;
    }

    /**
     * @return a short label for this solver
     */
    @Override
    public String label() {
        return "Policy Iter. Solver(disc)";
    }

    /**
     * @return the duration of the last {@link #solve()} call in milliseconds
     */
    @Override
    public final long getProcessTime() {
        return this.processTime;
    }

    /**
     * @return the number of evaluation/improvement rounds of the last {@link #solve()} call
     */
    @Override
    public final long getIterations() {
        return this.iterations;
    }
}

