package net.pakl.rl;

import java.util.*;
import java.util.zip.*;
import java.io.*;

/**
 * Extracts a greedy policy from a value function by simulating one episode.
 * <p>
 * Starting from the test world's starting state (or a state supplied through
 * {@link #forceInitialState(State)}), each step asks a generic ("naive")
 * policy for the allowed actions. It then takes the action that maximises
 * {@code immediateReward + discountFactor * V(nextState)}, where the next
 * state is obtained from {@code testWorld.getNewState}. Ties go to the action
 * evaluated last. When only one action is allowed, it is taken without
 * consulting the value function.
 * <p>
 * This supports "Patch" value functions. When
 * {@link #DYNAMICALLY_LOAD_VALUE_FUNCTION_PATCHES} is enabled, a new value
 * function is loaded from {@code PATCHES/<patchName>.obj.gz} whenever the
 * patch name reported by the current state does not match the name of the
 * currently loaded value function. Successor states outside the loaded patch
 * are then skipped during action selection.
 * <p>
 * NOTE(review): the visited-state, visited-value and chosen-action lists and
 * the action counter are instance fields that are never reset. Calling
 * {@link #extractOptimalPolicy} more than once on the same instance therefore
 * accumulates results across calls; use a fresh instance per extraction.
 */
public class PolicyExtractor
{
    /** When true, value-function patches are loaded on demand from PATCHES/ based on each state's patch name. */
    public static boolean DYNAMICALLY_LOAD_VALUE_FUNCTION_PATCHES = false; 
    /** Step limit; once exceeded, extraction stops and returns a failure message. */
    public static final int MAXIMUM_ALLOWED_NUM_ACTIONS = 1000;
    /** Actions taken, in order. */
    List optimalActions = new ArrayList();
    /** States visited, in order (including the terminal state, if reached). */
    List visitedStates = new ArrayList();
    /** Value-function value (as Double) of each entry in visitedStates. */
    List visitedStateValues = new ArrayList();
    /** Number of steps taken so far. */
    int totalNumActions = 0;
        
    /** Optional starting state overriding the test world's starting state; null means "use the world's". */
    State forcedInitialState = null;
    
    /**
     * Sets the state that extraction should start from instead of the test
     * world's starting state.
     *
     * @param s the starting state to use, or null to use the test world's starting state
     */
    public void forceInitialState(State s)
    {
        this.forcedInitialState = s;
    }
    
    /**
     * Runs one greedy episode and returns a textual report of it. Progress is
     * logged to System.out and System.err as the episode runs.
     *
     * @param naivePolicy    supplies the actions allowed in each state
     * @param valueFunction  value function used to score successor states; replaced
     *                       during the run when patch loading is enabled
     * @param trainedWorld   not used by this method
     * @param testWorld      supplies the starting state, transitions and terminal-state test
     * @param rf             reward function used both for scoring and for the reported rewards
     * @param discountFactor weight applied to the successor state's value when scoring
     * @return a report of visited states, their values, the chosen actions and the
     *         collected rewards; or a message starting with "PolicyExtractor failed"
     *         if no action could be selected or the step limit was exceeded
     * @throws RuntimeException if patch loading is enabled and a state does not
     *         implement SubdivisionIdentification, or a patch file cannot be loaded
     */
    public String extractOptimalPolicy(ActionSet naivePolicy, ValueFunction valueFunction, World trainedWorld, World testWorld, ReinforcementFunction rf, double discountFactor)
    {
        boolean reachedTerminalState = false;
        StringBuffer report = new StringBuffer("");
        List rewards = new ArrayList();

        System.err.println("POLICY EXTRACTOR STARTED");
        System.err.println("Naive Policy describing possible actions: " + naivePolicy.getClass());
        State currentState = forcedInitialState;
        if (currentState == null)
        {
            currentState = testWorld.getStartingState();
        }

        while (!reachedTerminalState)
        {
            System.out.println("CRNT: " + currentState);
            List possibleActions = naivePolicy.getAllPossibleActions(currentState);
            System.out.println("PSBL: " + possibleActions);

            if (DYNAMICALLY_LOAD_VALUE_FUNCTION_PATCHES)
            {
                valueFunction = ensurePatchLoaded(currentState, valueFunction);
            }

            Action nextAction = selectAction(possibleActions, currentState, valueFunction, testWorld, rf, discountFactor);
            System.out.println("BEST: " + nextAction);
            if (nextAction == null)
            {
                // Nothing was allowed or scorable here; stepping with a null action
                // would hand null to the world and the reward function.
                return "PolicyExtractor failed, no selectable action in state: " + currentState;
            }

            totalNumActions++;
            if (totalNumActions > MAXIMUM_ALLOWED_NUM_ACTIONS)
            {
                return "PolicyExtractor failed, never reached terminal state.\nGot stuck doing:"
                        + visitedStates.get(visitedStates.size()-2) + " <-> " + visitedStates.get(visitedStates.size()-1)
                        + "\nMost recent actions: " + optimalActions.get(optimalActions.size()-2) + " and "
                        + optimalActions.get(optimalActions.size()-1);
            }

            optimalActions.add(nextAction);
            visitedStates.add(currentState);
            visitedStateValues.add(new Double(valueFunction.getValue(currentState)));

            State previousState = currentState;
            currentState = testWorld.getNewState(previousState, nextAction);
            double reward = rf.getReward(previousState, nextAction, currentState);
            rewards.add(new Double(reward));
            System.out.println("TOOK: " + nextAction);
            System.out.println("--->: " + currentState + "\n");

            if (testWorld.isTerminalState(currentState))
            {
                reachedTerminalState = true;
                System.out.println("PolicyExtractor: " +currentState+" is a terminal state, stopping.");
                visitedStates.add(currentState);
                visitedStateValues.add(new Double(valueFunction.getValue(currentState)));
            }
        }
        System.err.println("POLICY EXTRACTOR ENDED.\n");
        report.append("\n\nVisited States:\n" + displayList(visitedStates, "visitedStates") + "\n");
        report.append("\n\nVisited Values:\n" + displayList(visitedStateValues, "visitedValues") + "\n");
        report.append("\n\nOptimal Actions:\n" + displayList(optimalActions, "optimalActions") + "\n");
        report.append("Total Rewards:" + rewards + " Sum: " + sumOf(rewards)+"\n\n");
        return report.toString();
    }

    /**
     * Returns a value function whose name matches the current state's patch,
     * loading PATCHES/&lt;patchName&gt; when the given one is null or mismatched.
     *
     * @param currentState  state whose patch name determines the required value function
     * @param valueFunction currently loaded value function, possibly null
     * @return valueFunction itself if it already matches, otherwise the freshly loaded patch
     * @throws RuntimeException if the state does not implement SubdivisionIdentification
     *         or the patch cannot be loaded
     */
    private ValueFunction ensurePatchLoaded(State currentState, ValueFunction valueFunction)
    {
        if (!(currentState instanceof SubdivisionIdentification))
        {
            throw new RuntimeException("Sorry, but states of type " + currentState.getClass() + " do not implement"
                    +" subdivision identification and so we don't know what patch of value function to load.");
        }

        String patchName = currentState.getPatchName();
        if (valueFunction != null && valueFunction.getName().equals(patchName))
        {
            return valueFunction;
        }

        if (valueFunction == null)
        {
            System.out.print("(No value function loaded yet.) ");
        }
        else
        {
            System.out.print("(Name of loaded value function ("+valueFunction.getName()
                +") does not match required patch ("+patchName+").");
        }

        ValueFunction patch = loadValueFunction("PATCHES/"+patchName);
        if (patch instanceof ValueFunctionHashMap)
        {
            // Unstored states get the lowest finite value so they are practically never preferred.
            ((ValueFunctionHashMap) patch).setValueOfNonStoredStates(-Double.MAX_VALUE);
        }
        return patch;
    }

    /**
     * Picks the allowed action with the highest reward + discounted successor value.
     *
     * @param possibleActions actions allowed in currentState (elements are Action)
     * @param currentState    state to act from
     * @param valueFunction   value function used to score successor states
     * @param testWorld       world used to compute each candidate's successor state
     * @param rf              reward function for the immediate reward
     * @param discountFactor  weight applied to the successor's value
     * @return the best action (later candidates win ties); the sole action if only one
     *         is allowed; or null if none could be scored at or above -Double.MAX_VALUE
     */
    private Action selectAction(List possibleActions, State currentState, ValueFunction valueFunction,
            World testWorld, ReinforcementFunction rf, double discountFactor)
    {
        if (possibleActions.size() == 1)
        {
            // A forced move needs no evaluation and never touches the value function.
            return (Action) possibleActions.get(0);
        }

        Action bestAction = null;
        double bestScore = -Double.MAX_VALUE;
        Iterator it = possibleActions.iterator();
        while (it.hasNext())
        {
            Action testAction = (Action) it.next();
            State nextState = testWorld.getNewState(currentState, testAction);
            try
            {
                // Evaluate reward and value once; both feed the log line and the comparison.
                double reward = rf.getReward(currentState, testAction, nextState);
                double discountedValue = discountFactor * valueFunction.getValue(nextState);
                System.out.println("TEST: " + testAction + " => " 
                        + nextState + " = " 
                        + reward 
                        + "+" + discountedValue);

                double score = reward + discountedValue;
                if (score >= bestScore)
                {
                    bestAction = testAction;
                    bestScore = score;
                }
            }
            catch (NonLearnedStateException e)
            {
                if (!DYNAMICALLY_LOAD_VALUE_FUNCTION_PATCHES)
                {
                    throw e;
                }
                // With patches, a successor outside the loaded patch is simply not a candidate.
                System.out.println("Ignoring non-trained state (outside of patch): " + nextState);
            }
        }
        return bestAction;
    }

    /** Formats a double array as "[a b c ]". Currently unused within this class. */
    private String showVector(double [] x)
    {
        StringBuffer result = new StringBuffer();
        result.append("[");
        for (int i = 0; i < x.length; i++)
        {
            result.append(x[i] + " ");
        }
        result.append("]");
        return result.toString();
    }
        
    /**
     * Formats each list element on its own line as "prefix index - element".
     *
     * @param list   elements to format
     * @param prefix label placed at the start of every line
     * @return the formatted lines, each terminated by a newline
     */
    private String displayList(List list, String prefix)
    {
        StringBuffer result = new StringBuffer();
        for (int i = 0; i < list.size(); i++)
        {
            result.append(prefix).append(' ').append(i).append(" - ").append(list.get(i)).append('\n');
        }
        return result.toString();
    }
    
    /**
     * Formats each list element on its own line, then strips every '[', ']' and
     * the substring "state". Currently unused within this class.
     */
    private String displayListClean(List list)
    {
        StringBuffer buffer = new StringBuffer();
        for (int i = 0; i < list.size(); i++)
        {
            buffer.append(list.get(i)).append('\n');
        }
        String result = buffer.toString();
        result = result.replaceAll("\\[", "");
        result = result.replaceAll("\\]", "");
        result = result.replaceAll("state", "");
        return result;
    }
        
    /**
     * Sums a list of Double objects.
     *
     * @param ofDoubles list whose elements are all java.lang.Double
     * @return the sum, or 0 for an empty list
     */
    private double sumOf(List ofDoubles)
    {
        Iterator i = ofDoubles.iterator();
        double sum = 0;
        while (i.hasNext())
        {
            sum = sum + ((Double) i.next()).doubleValue();
        }
        return sum;
    }
        
    /**
     * Returns the live list of actions chosen so far (not a copy).
     *
     * @return the chosen actions, in order
     */
    public java.util.List getOptimalActions()
    {
        return optimalActions;
    }
    
    /**
     * Replaces the list that chosen actions are appended to.
     *
     * @param optimalActions the list to use from now on
     */
    public void setOptimalActions(java.util.List optimalActions)
    {
        this.optimalActions = optimalActions;
    }
    
    /**
     * Deserializes a gzip-compressed ValueFunction from {@code filename + ".obj.gz"}.
     * <p>
     * NOTE(review): this uses Java native deserialization; only load files from
     * a trusted source.
     *
     * @param filename path without the ".obj.gz" suffix
     * @return the deserialized value function
     * @throws RuntimeException (with the original failure as its cause) if the file
     *         cannot be opened, read or deserialized as a ValueFunction
     */
    private ValueFunction loadValueFunction(String filename)
    {
        System.out.println("Loading pre-existing value function "+filename+".");
        // Tracks the outermost stream opened so far so that the whole chain is closed,
        // even if a wrapping constructor (e.g. a bad gzip header) fails.
        InputStream stream = null;
        try
        {
            stream = new FileInputStream(filename+".obj.gz");
            stream = new GZIPInputStream(stream);
            ObjectInputStream in = new ObjectInputStream(stream);
            stream = in;
            ValueFunction vf = (ValueFunction) in.readObject();
            System.out.println("Value function of type "+vf.getClass()+" loaded from " + filename);
            if (vf instanceof ValueFunctionResidualAlgorithmLinear)
            {
                System.out.println("ValueFunction is linear network!");
            }
            if (vf instanceof ValueFunctionResidualAlgorithmPerceptron)
            {
                System.out.println("ValueFunction is perceptron network!");
            }
            return vf;
        }
        catch (Exception e) 
        { 
            System.out.println("Reading value function from file "+filename+" failed.\n");
            e.printStackTrace();
            throw new RuntimeException("Value Function loading failed.", e);
        }
        finally
        {
            closeQuietly(stream);
        }
    }        

    /**
     * Closes the stream if it is non-null. A failure to close is ignored because
     * by then the data has been read or loading has already failed.
     */
    private static void closeQuietly(InputStream stream)
    {
        if (stream == null)
        {
            return;
        }
        try
        {
            stream.close();
        }
        catch (IOException ignored)
        {
            // Nothing useful to do on a close failure.
        }
    }

}
