/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.rl;

import java.util.Iterator;
import java.util.List;
import java.util.Random;
import net.pakl.rl.Action;
import net.pakl.rl.ActionSet;
import net.pakl.rl.ReinforcementFunction;
import net.pakl.rl.State;
import net.pakl.rl.ValueFunction;
import net.pakl.rl.ValueFunctionResidualAlgorithmLinear;
import net.pakl.rl.ValueFunctionResidualAlgorithmPerceptron;
import net.pakl.rl.World;

/**
 * Agent that improves a {@link ValueFunction} by one-step look-ahead backups, either by sweeping
 * over every state of a {@link World} ({@link #performValueIteration}) or by following a single
 * trajectory from the world's starting state ({@link #performValueIterationTrajectorySample}).
 *
 * <p>The agent must be given a world, a policy (the set of actions available per state) and a
 * reinforcement function through the corresponding setters before any update method is called.
 *
 * <p>Instances keep mutable bookkeeping (delta statistics, a random generator) and share static
 * counters; nothing here is synchronized.
 */
public class Agent {
    private static final String VERSION = "v1.57 January 27, 2006";

    /** Minimum time between two progress lines printed by {@link #performValueIteration}. */
    private static final long PROGRESS_INTERVAL_MILLIS = 10000L;

    /**
     * When {@code true}, the trajectory sampler moves on without updating the value function on
     * steps where the action was picked at random rather than greedily.
     */
    public static boolean UPDATE_ONLY_WHEN_GREEDY = true;

    protected State state;
    protected ActionSet policy;
    protected World world;
    protected ReinforcementFunction reinforcementFunction;
    protected String name;
    protected final long DEFAULT_RANDOM_NUMBER_SEED = 12345L;
    protected Random random = new Random(DEFAULT_RANDOM_NUMBER_SEED);

    /** Multiplier applied to the successor state's value when computing a backed-up value. */
    public double discountFactor = 1.0;

    /** Threshold compared against a uniform draw in [0,1); a draw below it selects greedily. */
    public double greed = 1.0;

    /** Step size of the plain update {@code old + epsilon * delta}. */
    public double epsilon = 1.0;

    double maximumDelta;
    double averageDelta;
    double totalDelta;
    protected int callNumber = 0;

    /** The trajectory sampler prints its summary once per this many calls. */
    public static int printStartStateValueEvery = 10000;
    static int printStartStateValueCounter = 0;

    /** Outcome of trying one action from one state: where it leads and what it is worth. */
    private static final class Transition {
        final Action action;
        final State nextState;
        final double reward;
        final double nextStateValue;
        /** {@code reward + discountFactor * nextStateValue}, fixed at evaluation time. */
        final double backedUpValue;

        Transition(Action action, State nextState, double reward, double nextStateValue, double backedUpValue) {
            this.action = action;
            this.nextState = nextState;
            this.reward = reward;
            this.nextStateValue = nextStateValue;
            this.backedUpValue = backedUpValue;
        }
    }

    /** Creates an agent named {@code "agent"}. */
    public Agent() {
        this("agent");
    }

    /**
     * Creates a named agent and prints the version banner to standard output.
     *
     * @param newName name reported by {@link #getName()} and used in log messages
     */
    public Agent(String newName) {
        this.name = newName;
        String rule = "-----------------------------------------------------------------------------------------------";
        System.out.println(rule);
        System.out.println("  RL Agent VERSION " + VERSION + " patryk@cnbc.cmu.edu");
        System.out.println(rule);
        System.out.println("UPDATE_ONLY_WHEN_GREEDY = " + UPDATE_ONLY_WHEN_GREEDY);
    }

    /** Sets the world queried for states, transitions and terminal tests. */
    public void setWorld(World newWorld) {
        this.world = newWorld;
    }

    /** Replaces the random generator with a fresh one seeded with {@code newSeed}. */
    public void initializeRandomSeed(long newSeed) {
        this.random = new Random(newSeed);
    }

    /**
     * Performs one greedy backup for {@code currentState}: evaluates every available action
     * against {@code valueFunction}, picks the one with the highest backed-up value (first one
     * wins ties) and writes the result into {@code newValueFunction}.
     *
     * <p>Adds |delta| to the running total and raises the running maximum delta if exceeded.
     *
     * @param newValueFunction value function that receives the update
     * @param valueFunction value function that is read from
     * @param currentState state being backed up
     * @throws RuntimeException if the policy offers no action from {@code currentState}
     */
    protected void performValueIterationUpdateOnState(ValueFunction newValueFunction, ValueFunction valueFunction, State currentState) {
        List<? extends Action> actions = this.policy.getAllPossibleActions(currentState);
        if (actions.isEmpty()) {
            throw new RuntimeException(noActionMessage(currentState));
        }
        Transition first = this.evaluateTransition(valueFunction, currentState, actions.get(0));
        Transition best = this.pickGreedyTransition(valueFunction, currentState, actions, first);
        // NOTE(review): here the perceptron receives the full backed-up value as its target, while
        // the trajectory sampler hands it the step-size-smoothed value. Confirm which is intended.
        double magnitude = Math.abs(this.applyBackup(newValueFunction, valueFunction, currentState, best, false));
        this.totalDelta += magnitude;
        // Previously maximumDelta was reset by performValueIteration but never written again.
        this.maximumDelta = Math.max(this.maximumDelta, magnitude);
    }

    /**
     * Applies a single observed transition as an update:
     * {@code V(start) += epsilon * (reinforcement + discountFactor * V(next) - V(start))},
     * reading from {@code valueFunction} and writing into {@code newValueFunction}.
     *
     * @param newValueFunction value function that receives the update
     * @param valueFunction value function that is read from
     * @param startState state the transition left
     * @param nextState state the transition arrived in
     * @param reinforcement reward observed for the transition
     */
    public void experience(ValueFunction newValueFunction, ValueFunction valueFunction, State startState, State nextState, double reinforcement) {
        double nextValue = valueFunction.getValue(nextState);
        double startValue = valueFunction.getValue(startState);
        double newDelta = reinforcement + this.discountFactor * nextValue - startValue;
        newValueFunction.setValue(startState, startValue + this.epsilon * newDelta);
    }

    /**
     * Walks one trajectory from the world's starting state until a terminal state is reached,
     * updating {@code newValueFunction} along the way.
     *
     * <p>At each step a uniform draw below {@link #greed} selects the greedy action; otherwise a
     * uniformly random action is taken. Non-greedy steps are not used for updates while
     * {@link #UPDATE_ONLY_WHEN_GREEDY} is set. Resets and accumulates the total delta. Every
     * {@link #printStartStateValueEvery} calls it prints the start-state value and a summary.
     *
     * @param newValueFunction value function that receives the updates
     * @param valueFunction value function that is read from
     * @throws RuntimeException if a non-terminal state with no available action is reached
     */
    public void performValueIterationTrajectorySample(ValueFunction newValueFunction, ValueFunction valueFunction) {
        int pathLength = 0;
        this.totalDelta = 0.0;
        State currentState = this.world.getStartingState();
        if (++printStartStateValueCounter >= printStartStateValueEvery) {
            System.out.print(" V(startState) = " + valueFunction.getValue(currentState));
        }
        while (!this.world.isTerminalState(currentState)) {
            ++pathLength;
            List<? extends Action> actions = this.policy.getAllPossibleActions(currentState);
            if (actions.isEmpty()) {
                throw new RuntimeException(noActionMessage(currentState));
            }
            // The first action is always evaluated before the greedy/random draw, even if a random
            // action ends up being taken; kept so the sequence of world and RNG calls is unchanged.
            Transition chosen = this.evaluateTransition(valueFunction, currentState, actions.get(0));
            boolean wasGreedy = this.random.nextDouble() < this.greed;
            if (wasGreedy) {
                chosen = this.pickGreedyTransition(valueFunction, currentState, actions, chosen);
            } else {
                int randomActionNumber = (int)(this.random.nextDouble() * (double)actions.size());
                chosen = this.evaluateTransition(valueFunction, currentState, actions.get(randomActionNumber));
            }
            if (wasGreedy || !UPDATE_ONLY_WHEN_GREEDY) {
                this.totalDelta += Math.abs(this.applyBackup(newValueFunction, valueFunction, currentState, chosen, true));
            }
            currentState = chosen.nextState;
        }
        if (newValueFunction instanceof ValueFunctionResidualAlgorithmLinear) {
            ValueFunctionResidualAlgorithmLinear linear = (ValueFunctionResidualAlgorithmLinear)newValueFunction;
            linear.storeWeightChangesIfNonIncremental();
            System.out.print(" deltaW " + linear.getTotalIncrementalWeightChangeAndReset() + " ");
        }
        if (newValueFunction instanceof ValueFunctionResidualAlgorithmPerceptron) {
            ((ValueFunctionResidualAlgorithmPerceptron)newValueFunction).storeWeightChangesIfNonIncremental();
        }
        if (printStartStateValueCounter >= printStartStateValueEvery) {
            System.out.println(" Trajectory length = " + pathLength + " total delta = " + this.totalDelta);
            printStartStateValueCounter = 0;
        }
    }

    /**
     * Performs one full sweep: backs up every state returned by the world's state iterator, then
     * prints average, maximum and total delta for the sweep.
     *
     * @param newValueFunction value function that receives the updates
     * @param valueFunction value function that is read from
     * @return {@code newValueFunction}
     * @throws RuntimeException if the world yields no states, or if updating a state fails (the
     *     offending state is printed to standard error before the exception is rethrown)
     */
    public ValueFunction performValueIteration(ValueFunction newValueFunction, ValueFunction valueFunction) {
        long lastUpdatePrinted = System.currentTimeMillis();
        int statesProcessed = 0;
        ++this.callNumber;
        this.totalDelta = 0.0;
        this.maximumDelta = 0.0;
        Iterator<?> stateIterator = this.world.stateIterator();
        if (!stateIterator.hasNext()) {
            throw new RuntimeException("Value Iteration has no states to iterate through; if using regular value iteration make sure safe mode is enabled (valueFunctionSafety=true) because the agent iterates over the state keys in the value function.");
        }
        while (stateIterator.hasNext()) {
            State currentState = (State)stateIterator.next();
            try {
                ++statesProcessed;
                this.performValueIterationUpdateOnState(newValueFunction, valueFunction, currentState);
                if (System.currentTimeMillis() - lastUpdatePrinted > PROGRESS_INTERVAL_MILLIS) {
                    int totalStates = this.world.getNumberOfStates();
                    String progress = statesProcessed + " states processed";
                    if (totalStates > 0) {
                        // long arithmetic: 100 * count overflows int beyond ~21 million states.
                        progress = progress + " (" + 100L * statesProcessed / totalStates + "%)";
                    }
                    System.out.println(progress);
                    lastUpdatePrinted = System.currentTimeMillis();
                }
            }
            catch (RuntimeException e) {
                System.err.println("The error occured while exploring from state " + currentState);
                throw e;
            }
        }
        if (newValueFunction instanceof ValueFunctionResidualAlgorithmLinear) {
            ValueFunctionResidualAlgorithmLinear linear = (ValueFunctionResidualAlgorithmLinear)newValueFunction;
            linear.storeWeightChangesIfNonIncremental();
            System.out.println("Total weight change if incremental: " + linear.getTotalIncrementalWeightChangeAndReset());
        }
        if (newValueFunction instanceof ValueFunctionResidualAlgorithmPerceptron) {
            ((ValueFunctionResidualAlgorithmPerceptron)newValueFunction).storeWeightChangesIfNonIncremental();
        }
        this.averageDelta = this.totalDelta / (double)statesProcessed;
        System.out.println("\navgDelta = " + this.averageDelta + " maxDelta was " + this.maximumDelta + " totalDelta was " + this.totalDelta + " for " + statesProcessed + " iterated states.");
        return newValueFunction;
    }

    /**
     * Returns the action from {@code s} whose backed-up value
     * {@code reward + discountFactor * V(successor)} is largest under {@code vf}.
     *
     * @param s state to act from
     * @param vf value function used to score successor states
     * @param p action set queried for the actions available in {@code s}
     * @return the best action (the earliest one on ties); the first action if none compares
     *     greater than negative infinity; {@code null} if {@code p} offers no action
     */
    public Action getBestActionForValueFrom(State s, ValueFunction vf, ActionSet p) {
        double maxValue = Double.NEGATIVE_INFINITY;
        Action result = null;
        for (Action a : p.getAllPossibleActions(s)) {
            if (result == null) {
                result = a;
            }
            State newState = this.world.getNewState(s, a);
            double newStateValue = vf.getValue(newState);
            double reward = this.reinforcementFunction.getReward(s, a, newState);
            double maximizeMe = reward + this.discountFactor * newStateValue;
            if (maximizeMe > maxValue) {
                // Track the quantity being maximised; storing newStateValue here made later
                // candidates compare against the wrong baseline.
                maxValue = maximizeMe;
                result = a;
            }
        }
        return result;
    }

    public double getDiscountFactor() {
        return this.discountFactor;
    }

    public void setDiscountFactor(double discountFactor) {
        this.discountFactor = discountFactor;
    }

    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /** Returns the largest |delta| seen by per-state updates since the last sweep started. */
    public double getMaximumDelta() {
        return this.maximumDelta;
    }

    /** Returns total delta divided by the number of states of the last completed sweep. */
    public double getAverageDelta() {
        return this.averageDelta;
    }

    /** Returns the sum of |delta| accumulated since the last sweep or trajectory started. */
    public double getTotalDelta() {
        return this.totalDelta;
    }

    /** Stores the agent's current state and logs an acknowledgement to standard error. */
    public void setState(State newState) {
        this.state = newState;
        // Plain concatenation so a null state is logged as "null" instead of throwing.
        System.err.println("Agent " + this.getName() + ": \"Acknowledged, I have been moved to state '" + newState + "'.\"");
    }

    /** Stores the action set used as policy and logs an acknowledgement to standard error. */
    public void setPolicy(ActionSet newPolicy) {
        this.policy = newPolicy;
        System.err.println("Agent " + this.getName() + ": \"Acknowledged, I have received a Policy.\"");
    }

    /** Sets the function that supplies the reward of a (state, action, next state) transition. */
    public void setReinforcementFunction(ReinforcementFunction newReinforcementFunction) {
        this.reinforcementFunction = newReinforcementFunction;
    }

    public String getName() {
        return this.name;
    }

    /** Returns a new matrix holding {@code a * v[i][j]}; rows are sized individually. */
    private double[][] multiply(double a, double[][] v) {
        double[][] result = new double[v.length][];
        for (int i = 0; i < v.length; ++i) {
            result[i] = new double[v[i].length];
            for (int j = 0; j < v[i].length; ++j) {
                result[i][j] = a * v[i][j];
            }
        }
        return result;
    }

    /** Returns a new matrix holding {@code a[i][j] + b[i][j]}; {@code b} must cover {@code a}. */
    private double[][] add(double[][] a, double[][] b) {
        double[][] result = new double[a.length][];
        for (int i = 0; i < a.length; ++i) {
            result[i] = new double[a[i].length];
            for (int j = 0; j < a[i].length; ++j) {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }

    public void setGreed(double greed) {
        this.greed = greed;
    }

    /** Simulates {@code action} from {@code from}: successor, then reward, then successor value. */
    private Transition evaluateTransition(ValueFunction valueFunction, State from, Action action) {
        State nextState = this.world.getNewState(from, action);
        double reward = this.reinforcementFunction.getReward(from, action, nextState);
        double nextStateValue = valueFunction.getValue(nextState);
        return new Transition(action, nextState, reward, nextStateValue, reward + this.discountFactor * nextStateValue);
    }

    /** Scans actions 1..n-1 and returns whichever strictly beats {@code best}, else {@code best}. */
    private Transition pickGreedyTransition(ValueFunction valueFunction, State from, List<? extends Action> actions, Transition best) {
        for (int i = 1; i < actions.size(); ++i) {
            Transition candidate = this.evaluateTransition(valueFunction, from, actions.get(i));
            if (candidate.backedUpValue > best.backedUpValue) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Writes the backup described by {@code chosen} into {@code newValueFunction} and returns the
     * signed delta (backed-up value minus the current value of {@code currentState}).
     *
     * @param smoothPerceptronTarget if {@code true} a perceptron value function receives
     *     {@code old + epsilon * delta} as its target, otherwise the full backed-up value
     */
    private double applyBackup(ValueFunction newValueFunction, ValueFunction valueFunction, State currentState, Transition chosen, boolean smoothPerceptronTarget) {
        double oldValue = valueFunction.getValue(currentState);
        double newDelta = chosen.backedUpValue - oldValue;
        boolean isLinear = newValueFunction instanceof ValueFunctionResidualAlgorithmLinear;
        boolean isPerceptron = newValueFunction instanceof ValueFunctionResidualAlgorithmPerceptron;
        if (isLinear || isPerceptron) {
            // NOTE(review): the result is unused. The call is kept only because it queries the
            // world and reward function, which may have side effects; drop it if they are pure.
            this.getBestActionForValueFrom(chosen.nextState, valueFunction, this.policy);
            // Two independent checks (not else-if), matching the original control flow.
            if (isLinear) {
                ((ValueFunctionResidualAlgorithmLinear)newValueFunction).setValue(currentState, chosen.nextState, chosen.backedUpValue, this.discountFactor);
            }
            if (isPerceptron) {
                double target = smoothPerceptronTarget ? oldValue + this.epsilon * newDelta : chosen.backedUpValue;
                ((ValueFunctionResidualAlgorithmPerceptron)newValueFunction).setValue(currentState, chosen.nextState, target, this.discountFactor);
            }
        } else {
            newValueFunction.setValue(currentState, oldValue + this.epsilon * newDelta);
        }
        return newDelta;
    }

    /** Builds the error text used when a state offers no action. */
    private static String noActionMessage(State stuckState) {
        return "\nArrived at a state (" + stuckState + ") from which no action exists. " + "This is not allowed by this framework.\nE.g., from a terminal states, " + "you should specify actions returning to the terminal state.";
    }
}

