package net.pakl.rl;

import java.util.*;


/**
 * This class represents an agent, which is capable of iterating
 * over all of the actions from its given state given a {@link ActionSet}.
 * An agent maintains its own state and can be assigned a fixed policy.
 *
 * <p>The agent learns a state-value function either by sweeping over every
 * state of its {@link World} ({@link #performValueIteration}) or by following
 * a single trajectory from the world's starting state to a terminal state
 * ({@link #performValueIterationTrajectorySample}). In both cases the value of
 * a state is moved towards <code>R + discountFactor * V(next)</code> by the
 * fraction {@link #epsilon}.
 *
 * <p>The world, policy and reinforcement function must be supplied through
 * their setters before any learning method is called; the learning methods
 * dereference them without checking.
 */
public class Agent
{
    private static final String VERSION = "v1.57 January 27, 2006";

    /** When true, trajectory sampling performs no value update on steps where
     * a random (non-greedy) action was taken. */
    public static boolean UPDATE_ONLY_WHEN_GREEDY = true;

    /** Trajectory sampling prints the start-state value and trajectory
     * statistics once every this many calls (counted across all agents). */
    public static int printStartStateValueEvery = 10000;
    static int printStartStateValueCounter = 0;

    protected State state;				// Current state of the agent

    protected ActionSet policy;
    protected World world;
    protected ReinforcementFunction reinforcementFunction;
    protected String name;
    protected final long DEFAULT_RANDOM_NUMBER_SEED = 12345;
    protected Random random = new Random(DEFAULT_RANDOM_NUMBER_SEED);


    /** The higher this is, the more the future (distant reward) becomes important in making
     * a decision -- but beware of infinite loops and explosions of value. */
    public double discountFactor = 1.0d;

    /** Greediness (used during TrajectorySampling) reflects the probability of
     * taking the max action during learning;  so 1-greed is
     * the probability of taking a random action from allowable possible actions. */
    public double greed = 1.0d;

    /** This controls the rate at which states are changed to contain the new
     * value obtained by the learning algorithm -- for a stochastic world set this
     * to something less than 1 so you don't forget good experiences from a state.
     * The stored value becomes <code>old + epsilon * (target - old)</code>. */
    public double epsilon = 1.0d;

    /** Largest absolute update delta seen during the most recent sweep or trajectory. */
    double maximumDelta;
    /** Mean absolute update delta of the most recent full sweep. */
    double averageDelta;
    /** Sum of absolute update deltas of the most recent sweep or trajectory. */
    double totalDelta;

    /** Number of times {@link #performValueIteration} has been called. */
    protected int callNumber = 0;


    /**
     * Result of trying one action from a state: the action, the state it led
     * to, the reward received and the current value estimate of that state.
     */
    private static final class ActionOutcome
    {
        final Action action;
        final State nextState;
        final double reinforcement;
        final double nextStateValue;

        ActionOutcome(Action action, State nextState, double reinforcement, double nextStateValue)
        {
            this.action = action;
            this.nextState = nextState;
            this.reinforcement = reinforcement;
            this.nextStateValue = nextStateValue;
        }
    }


    /** Creates an agent named "agent". */
    public Agent()
    {
        this("agent");
    }

    /**
     * Creates an agent and prints a version banner to standard output.
     *
     * @param newName name reported by {@link #getName()} and used in log messages
     */
    public Agent(String newName)
    {
        this.name = newName;

        System.out.println("-----------------------------------------------------------------------------------------------");
        System.out.println("  RL Agent VERSION " + VERSION + " patryk@cnbc.cmu.edu");
        System.out.println("-----------------------------------------------------------------------------------------------");
        System.out.println("UPDATE_ONLY_WHEN_GREEDY = " + UPDATE_ONLY_WHEN_GREEDY);
    }

    /** Sets the world whose states and transitions this agent learns from. */
    public void setWorld(World newWorld)
    {
        this.world = newWorld;
    }

    /** Replaces the agent's random number generator with one seeded by <code>newSeed</code>. */
    public void initializeRandomSeed(long newSeed)
    {
        this.random = new Random(newSeed);
    }

    /** Evaluates a single action from <code>from</code> against the given value function. */
    private ActionOutcome evaluateAction(State from, Action action, ValueFunction valueFunction)
    {
        State next = world.getNewState(from, action);
        double reward = reinforcementFunction.getReward(from, action, next);
        double nextValue = valueFunction.getValue(next);
        return new ActionOutcome(action, next, reward, nextValue);
    }

    /** Returns the one-step backed-up value <code>R + discountFactor * V(next)</code> of an outcome. */
    private double backedUpValue(ActionOutcome outcome)
    {
        return outcome.reinforcement + discountFactor * outcome.nextStateValue;
    }

    /**
     * Returns the outcome with the highest backed-up value among
     * <code>incumbent</code> (the outcome of action 0) and actions 1..n-1.
     * Ties keep the earlier action because the comparison is strict.
     */
    private ActionOutcome pickBestRemaining(ActionOutcome incumbent, List allPossibleActions,
                                            State from, ValueFunction valueFunction)
    {
        ActionOutcome best = incumbent;
        double bestValue = backedUpValue(best);
        for (int i = 1; i < allPossibleActions.size(); i++)
        {
            ActionOutcome candidate = evaluateAction(from, (Action) allPossibleActions.get(i), valueFunction);
            double candidateValue = backedUpValue(candidate);
            if (candidateValue > bestValue)
            {
                best = candidate;
                bestValue = candidateValue;
            }
        }
        return best;
    }

    /** Throws if no action is available from <code>currentState</code>; the framework forbids such states. */
    private static void requireActionsExist(List allPossibleActions, State currentState)
    {
        if (allPossibleActions.size() > 0)
        {
            return;
        }
        throw new RuntimeException("\nArrived at a state ("+currentState+") from which no action exists. "
                +"This is not allowed by this framework.\nE.g., from a terminal states, "
                +"you should specify actions returning to the terminal state.");
    }

    /** Folds one update delta into the running total and running maximum. */
    private void recordDelta(double newDelta)
    {
        double magnitude = Math.abs(newDelta);
        totalDelta = totalDelta + magnitude;
        maximumDelta = Math.max(maximumDelta, magnitude);
    }

    // Note on residual-gradient learning in stochastic worlds: the updates
    // below use a single sampled successor.  If the residual gradient
    // algorithm is combined with stochastic worlds, a second step may need to
    // be sampled as well: choose the best action from the successor state
    // (see getBestActionForValueFrom), sample a further successor "state3"
    // from it, and form reward(successor, action, state3) + discountFactor *
    // V(state3), blended with V(successor) by epsilon.

    /**
     * Performs one value-iteration backup for a single state: finds the action
     * with the highest <code>R + discountFactor * V(next)</code> under
     * <code>valueFunction</code> and writes the updated value of
     * <code>currentState</code> into <code>newValueFunction</code>.
     *
     * <p>For the residual-algorithm value functions the unblended target is
     * handed over together with the successor state; for every other value
     * function the stored value is <code>old + epsilon * delta</code>. The
     * absolute delta is added to the running total and maximum.
     *
     * @param newValueFunction value function receiving the update
     * @param valueFunction    value function the backup is computed from
     * @param currentState     state to update
     * @throws RuntimeException if no action is possible from <code>currentState</code>
     */
    protected void performValueIterationUpdateOnState(ValueFunction newValueFunction, ValueFunction valueFunction, State currentState)
    {
        List allPossibleActions = policy.getAllPossibleActions(currentState);
        requireActionsExist(allPossibleActions, currentState);

        // Step 1: find the action leading to the highest backed-up value.
        ActionOutcome first = evaluateAction(currentState, (Action) allPossibleActions.get(0), valueFunction);
        ActionOutcome best = pickBestRemaining(first, allPossibleActions, currentState, valueFunction);

        // Step 2: move V(current) towards R + df * V(t+1).
        double oldValue = valueFunction.getValue(currentState);
        double target = backedUpValue(best);
        double newDelta = target - oldValue;

        boolean isLinear = newValueFunction instanceof ValueFunctionResidualAlgorithmLinear;
        boolean isPerceptron = newValueFunction instanceof ValueFunctionResidualAlgorithmPerceptron;
        if (isLinear || isPerceptron)
        {
            // Deliberately two independent checks (not else-if), as one type may extend the other.
            if (isLinear)
            {
                ((ValueFunctionResidualAlgorithmLinear) newValueFunction).setValue(currentState, best.nextState, target, discountFactor);
            }
            if (isPerceptron)
            {
                ((ValueFunctionResidualAlgorithmPerceptron) newValueFunction).setValue(currentState, best.nextState, target, discountFactor);
            }
        }
        else
        {
            newValueFunction.setValue(currentState, oldValue + epsilon * newDelta);
        }

        recordDelta(newDelta);
    }


    /**
     * Applies a single externally observed transition: moves the value of
     * <code>startState</code> towards
     * <code>reinforcement + discountFactor * V(nextState)</code> by the
     * fraction {@link #epsilon}.
     *
     * @param newValueFunction value function receiving the update
     * @param valueFunction    value function the current estimates are read from
     * @param startState       state the transition started in
     * @param nextState        state the transition ended in
     * @param reinforcement    reward observed on the transition
     */
    public void experience(ValueFunction newValueFunction, ValueFunction valueFunction, State startState, State nextState, double reinforcement)
    {
        double nextValue = valueFunction.getValue(nextState);
        double startValue = valueFunction.getValue(startState);
        double newDelta = reinforcement + (discountFactor * nextValue) - startValue;
        newValueFunction.setValue(startState, startValue + epsilon * newDelta);
    }

    /**
     * Follows one trajectory from the world's starting state until a terminal
     * state is reached, updating the value of each visited state on the way.
     *
     * <p>At every step the best action is taken with probability
     * {@link #greed}; otherwise a uniformly random action is taken. When
     * {@link #UPDATE_ONLY_WHEN_GREEDY} is set, steps taken randomly move the
     * agent but do not update any value. Total and maximum absolute deltas are
     * reset at the start and accumulated over the updated steps.
     *
     * @param newValueFunction value function receiving the updates
     * @param valueFunction    value function the backups are computed from
     * @throws RuntimeException if a visited state offers no action
     */
    public void performValueIterationTrajectorySample(ValueFunction newValueFunction, ValueFunction valueFunction)
    {
        int pathLength = 0;
        this.totalDelta = 0;
        this.maximumDelta = 0;
        State currentState = world.getStartingState();

        printStartStateValueCounter++;
        boolean reportThisCall = printStartStateValueCounter >= printStartStateValueEvery;
        if (reportThisCall)
        {
            System.out.print(" V(startState) = " + valueFunction.getValue(currentState));
        }

        boolean isLinear = newValueFunction instanceof ValueFunctionResidualAlgorithmLinear;
        boolean isPerceptron = newValueFunction instanceof ValueFunctionResidualAlgorithmPerceptron;

        while (!world.isTerminalState(currentState))
        {
            pathLength++;
            List allPossibleActions = policy.getAllPossibleActions(currentState);
            requireActionsExist(allPossibleActions, currentState);

            // Step 1: choose the action -- greedy with probability `greed`, random otherwise.
            boolean wasGreedy = random.nextDouble() < greed;
            ActionOutcome chosen;
            if (wasGreedy)
            {
                ActionOutcome first = evaluateAction(currentState, (Action) allPossibleActions.get(0), valueFunction);
                chosen = pickBestRemaining(first, allPossibleActions, currentState, valueFunction);
            }
            else
            {
                int randomActionNumber = (int) (random.nextDouble() * allPossibleActions.size());
                chosen = evaluateAction(currentState, (Action) allPossibleActions.get(randomActionNumber), valueFunction);
            }

            // Exploratory steps only move the agent when updates are restricted to greedy steps.
            if (UPDATE_ONLY_WHEN_GREEDY && !wasGreedy)
            {
                currentState = chosen.nextState;
                continue;
            }

            // Step 2: move V(current) towards R + df * V(t+1).
            double oldValue = valueFunction.getValue(currentState);
            double target = backedUpValue(chosen);
            double newDelta = target - oldValue;
            double blendedValue = oldValue + epsilon * newDelta;

            if (isLinear || isPerceptron)
            {
                // Deliberately two independent checks (not else-if), as one type may extend the other.
                if (isLinear)
                {
                    ((ValueFunctionResidualAlgorithmLinear) newValueFunction).setValue(currentState, chosen.nextState, target, discountFactor);
                }
                if (isPerceptron)
                {
                    // Unlike the linear variant, the perceptron is given the epsilon-blended value here.
                    ((ValueFunctionResidualAlgorithmPerceptron) newValueFunction).setValue(currentState, chosen.nextState, blendedValue, discountFactor);
                }
            }
            else
            {
                newValueFunction.setValue(currentState, blendedValue);
            }

            recordDelta(newDelta);
            currentState = chosen.nextState;
        }

        if (isLinear)
        {
            ValueFunctionResidualAlgorithmLinear v = ((ValueFunctionResidualAlgorithmLinear) newValueFunction);
            v.storeWeightChangesIfNonIncremental();
            System.out.print(" deltaW " + v.getTotalIncrementalWeightChangeAndReset()+" ");
        }
        if (isPerceptron)
        {
            ValueFunctionResidualAlgorithmPerceptron v = ((ValueFunctionResidualAlgorithmPerceptron) newValueFunction);
            v.storeWeightChangesIfNonIncremental();
        }

        if (reportThisCall)
        {
            System.out.println(" Trajectory length = "+pathLength + " total delta = " + totalDelta);
            printStartStateValueCounter = 0;
        }
    }




    /**
     * This is the MAIN learning function in the Agent: one full sweep of value
     * iteration over every state produced by the world's state iterator.
     *
     * <p>Resets and then accumulates the total, maximum and average absolute
     * deltas, prints a progress line roughly every ten seconds, and prints a
     * summary line at the end.
     *
     * @param newValueFunction value function receiving the updates
     * @param valueFunction    value function the backups are computed from
     * @return <code>newValueFunction</code>
     * @throws RuntimeException if the world yields no states, or if updating a
     *         state fails (the offending state is printed to standard error first)
     */
    public ValueFunction performValueIteration(ValueFunction newValueFunction, ValueFunction valueFunction)
    {
        callNumber++;
        totalDelta = 0.0d;
        maximumDelta = 0.0d;

        // Iterating over the KEYS of a keyed value function can be much faster
        // than generating states, but identical states may then end up
        // undertrained; this sweep therefore always uses the world's own iterator.
        Iterator stateIterator = this.world.stateIterator();

        if (!stateIterator.hasNext())
        {
            throw new RuntimeException("Value Iteration has no states to iterate through; if using regular value iteration make sure safe mode is enabled (valueFunctionSafety=true) because the agent iterates over the state keys in the value function.");
        }

        int statesProcessed = 0;
        long lastUpdatePrinted = System.currentTimeMillis();

        while (stateIterator.hasNext())
        {
            State currentState = (State) stateIterator.next();
            try
            {
                statesProcessed++;
                performValueIterationUpdateOnState(newValueFunction, valueFunction, currentState);

                if (System.currentTimeMillis() - lastUpdatePrinted > 10000)
                {
                    // long arithmetic avoids int overflow; the guard avoids division by zero.
                    int totalStates = world.getNumberOfStates();
                    long percent = (totalStates > 0) ? (100L * statesProcessed) / totalStates : 0L;
                    System.out.println(statesProcessed + " states processed ("+percent+"%)");
                    lastUpdatePrinted = System.currentTimeMillis();
                }
            }
            catch (RuntimeException e)
            {
                System.err.println("The error occured while exploring from state " + currentState);
                throw e;
            }
        }

        if (newValueFunction instanceof ValueFunctionResidualAlgorithmLinear)
        {
            ValueFunctionResidualAlgorithmLinear v = ((ValueFunctionResidualAlgorithmLinear) newValueFunction);
            v.storeWeightChangesIfNonIncremental();
            System.out.println("Total weight change if incremental: " + v.getTotalIncrementalWeightChangeAndReset());
        }

        if (newValueFunction instanceof ValueFunctionResidualAlgorithmPerceptron)
        {
            ValueFunctionResidualAlgorithmPerceptron v = ((ValueFunctionResidualAlgorithmPerceptron) newValueFunction);
            v.storeWeightChangesIfNonIncremental();
        }

        // statesProcessed is at least 1 here because an empty iterator was rejected above.
        averageDelta = totalDelta / statesProcessed;
        System.out.println("\navgDelta = " + averageDelta + " maxDelta was " + maximumDelta + " totalDelta was " + totalDelta + " for " + statesProcessed + " iterated states.");

        return newValueFunction;
    }


    /**
     * Returns the action from <code>s</code> with the highest
     * <code>reward + discountFactor * V(next)</code> under <code>vf</code>.
     * Ties keep the earliest action. If no action scores above negative
     * infinity (for example when every score is NaN), the first action is
     * returned.
     *
     * @param s  state to act from
     * @param vf value function used to score successor states
     * @param p  action set supplying the possible actions
     * @return the best action, or <code>null</code> if <code>p</code> offers no action from <code>s</code>
     */
    public Action getBestActionForValueFrom(State s, ValueFunction vf, ActionSet p)
    {
        double maxValue = Double.NEGATIVE_INFINITY;
        Action result = null;
        Iterator i = p.getAllPossibleActions(s).iterator();
        while (i.hasNext())
        {
            Action a = (Action) i.next();
            if (result == null)
            {
                result = a;
            }

            State newState = world.getNewState(s, a);
            double newStateValue = vf.getValue(newState);
            double reward = reinforcementFunction.getReward(s, a, newState);
            double maximizeMe = reward + discountFactor * newStateValue;
            if (maximizeMe > maxValue)
            {
                // Track the same quantity that is compared; storing only the
                // successor value here would make later comparisons meaningless.
                maxValue = maximizeMe;
                result = a;
            }
        }
        return result;
    }



    /** Returns the discount factor applied to successor-state values. */
    public double getDiscountFactor()
    {
        return discountFactor;
    }

    /** Sets the discount factor applied to successor-state values. */
    public void setDiscountFactor(double discountFactor)
    {
        this.discountFactor = discountFactor;
    }

    /** Sets the fraction of each update delta that is applied to a state's value. */
    public void setEpsilon(double epsilon)
    {
        this.epsilon = epsilon;
    }

    /** Returns the largest absolute update delta of the most recent sweep or trajectory. */
    public double getMaximumDelta()
    {
        return maximumDelta;
    }

    /** Returns the mean absolute update delta of the most recent full sweep. */
    public double getAverageDelta()
    {
        return averageDelta;
    }

    /** Returns the summed absolute update deltas of the most recent sweep or trajectory. */
    public double getTotalDelta()
    {
        return totalDelta;
    }

    /** Moves the agent to <code>newState</code> and logs an acknowledgement to standard error. */
    public void setState(State newState)
    {
        this.state = newState;
        // Plain concatenation prints "null" instead of throwing when newState is null.
        System.err.println("Agent " + getName() + ": \"Acknowledged, I have been moved to state '"+newState+"'.\"");
    }

    /** Assigns the action set used as the agent's policy and logs an acknowledgement to standard error. */
    public void setPolicy(ActionSet newPolicy)
    {
        this.policy = newPolicy;
        System.err.println("Agent " + getName() + ": \"Acknowledged, I have received a Policy.\"");
    }

    /** Sets the function that supplies the reward for each transition. */
    public void setReinforcementFunction(ReinforcementFunction newReinforcementFunction)
    {
        this.reinforcementFunction = newReinforcementFunction;
    }

    /** Returns the agent's name. */
    public String getName()
    {
        return this.name;
    }

    /** Returns a new matrix holding <code>a * v</code>; rows are sized individually, so empty and jagged input is accepted. */
    private static double [][] multiply(double a, double v[][])
    {
        double [][] result = new double[v.length][];
        for (int i = 0; i < v.length; i++)
        {
            result[i] = new double[v[i].length];
            for (int j = 0; j < v[i].length; j++)
            {
                result[i][j] = a * v[i][j];
            }
        }
        return result;
    }

    /** Returns a new matrix holding the element-wise sum of <code>a</code> and <code>b</code>, shaped like <code>a</code>. */
    private static double [][] add(double a[][], double b[][])
    {
        double [][] result = new double[a.length][];
        for (int i = 0; i < a.length; i++)
        {
            result[i] = new double[a[i].length];
            for (int j = 0; j < a[i].length; j++)
            {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }


    /** Sets the probability of taking the best action during trajectory sampling. */
    public void setGreed(double greed)
    {
        this.greed = greed;
    }
}
