package net.pakl.rl;

import java.util.*;
import java.io.*;

/**
 * A ValueFunction maps states, which are positions
 * in a {@link World}, to values, and may be replaced
 * with a neural network.   When the value function
 * is optimal then the optimal policy can be extracted
 * from it.
 *
 * <p>This implementation is a plain lookup table: one {@code Double} per
 * {@link State}, held in a {@link HashMap}. It has no generalization ability,
 * so a state that was never stored either raises
 * {@link NonLearnedStateException} (safe mode) or yields a configurable
 * default value (non-safe mode).
 */
public class ValueFunctionHashMap implements ValueFunction, Serializable
{
    static final long serialVersionUID = 5183569982779176111L;

    /** Lookup table from each stored state to its current value. */
    HashMap<State, Double> statesToValues = new HashMap<State, Double>();

    /** Value reported by {@link #getValue(State)} for states the world considers terminal. */
    private double valueOfTerminalStates = 0.0d;

    /** Not read or written anywhere in this class; kept for package-level users. */
    HashSet worldTerminalStates;

    /** Accumulator manipulated only through {@link #addTimer(double)}, {@link #resetTimer()} and {@link #getTimer()}. */
    private double timer = 0;

    /**
     * When true, {@link #init(World)} pre-populates every world state with zero and
     * {@link #getValue(State)} throws for states that are not stored.
     */
    boolean SAFE_MODE = true;  // Leave true for safety.

    /** Not read anywhere in this class. */
    boolean ZIPPED = false;

    /** Not read anywhere in this class. */
    boolean DATABASE = false;

    /**
     * Value configured via {@link #setValueOfOutOfBoundsStates(double)}.
     * It is stored here but never consulted by this class.
     */
    double valueOfOutOfBoundsStates = 0;

    /** Value returned by {@link #getValue(State)} for unknown states when safe mode is off. */
    double valueOfNonStoredStates = 0;

    /** Header object written first by {@link #saveTo(String)} and verified by {@link #loadFrom(String)}. */
    private static final String SAVED_IDENTIFICATION = "ValueFunction Saved Object";

    /** The world whose terminal/win/lose/draw predicates are consulted by {@link #getValue(State)}. */
    World world = null;

    private String name = "valuefunction";

    /**
     * Sets the value returned for states that are not stored when safe mode is off.
     *
     * @param x the default value for non-stored states
     */
    public void setValueOfNonStoredStates(double x)
    {
        this.valueOfNonStoredStates = x;
    }

    /**
     * Sets the name of this value function.
     *
     * @param name the new name
     */
    public void setName(String name)
    {
        this.name = name;
    }

    /**
     * @return the name of this value function ("valuefunction" unless changed)
     */
    public String getName()
    {
        return name;
    }

    /**
     * Switches safe mode on or off. This only changes the flag; it does not
     * re-run {@link #init(World)}.
     *
     * @param b the new safe-mode setting
     */
    public void setSafeMode(boolean b)
    {
        this.SAFE_MODE = b;
    }

    /**
     * Creates a value function for the given world using the default safe mode
     * (on), which stores an initial value of zero for every state the world
     * enumerates.
     *
     * While the initial ValueFunction is arbitrary,
     * it must be based on a World's states. The initial
     * values should probably be a function of the
     * number of states in the world, but I don't yet know
     * what that function should be.
     *
     * @param sourceWorld the world whose states this function covers
     */
    public ValueFunctionHashMap(World sourceWorld)
    {
        init(sourceWorld);
    }

    /**
     * Creates a value function for the given world with an explicit safe-mode setting.
     *
     * @param sourceWorld the world whose states this function covers
     * @param SAFE_MODE   whether to pre-populate all states and reject unknown ones
     */
    public ValueFunctionHashMap(World sourceWorld, boolean SAFE_MODE)
    {
        this.SAFE_MODE = SAFE_MODE;
        init(sourceWorld);
    }

    /**
     * @return the number of states currently stored (same as {@link #size()})
     */
    public int getSize()
    {
        return this.statesToValues.size();
    }

    /**
     * Replaces the world and re-runs {@link #init(World)}. Entries already in
     * the table are not cleared first.
     *
     * @param w the new world
     */
    public void setWorld(World w)
    {
        init(w);
    }

    /**
     * Binds this value function to a world. In safe mode every state produced by
     * {@code sourceWorld.stateIterator()} is stored with value zero; otherwise
     * the table is left untouched and only a notice is printed.
     *
     * @param sourceWorld the world to bind to
     * @throws RuntimeException if enumerating or storing the states fails in
     *                          safe mode; the original failure is attached as the cause
     */
    public void init(World sourceWorld)
    {
        world = sourceWorld;

        int counter = 0;

        if (SAFE_MODE)
        {
            try
            {
                System.out.println("Setting value of each state to zero.");
                Iterator statesIterator = sourceWorld.stateIterator();
                while (statesIterator.hasNext())
                {
                    State thisState = (State) statesIterator.next();
                    this.statesToValues.put(thisState, Double.valueOf(0.0d));
                    counter++;
                }
                System.out.println("Set "+counter+" items to zero.");
            }
            catch (Exception e)
            {
                System.err.println("Had problem initializing all states of the value function.\nIf you did not define a complete list of states\nin your World, you should set SAFE_MODE to\nfalse on the value function so that it will return a\ndefault value for unstored states.");
                System.err.println("The specific error was: " + e.getMessage());
                // Chain the cause so the original stack trace is not lost.
                throw new RuntimeException("Simulation terminated because value function not initialized.", e);
            }
        }
        else
        {
            System.out.println("ValueFunctionHashMap not in safe mode: will not throw exception if queried for non-existant state.");
        }
    }


    /**
     * Stores the value for out-of-bounds states. Note that this class never
     * reads the stored value itself.
     *
     * @param newValueOfOutOfBoundsStates the value to store
     */
    public void setValueOfOutOfBoundsStates(double newValueOfOutOfBoundsStates)
    {
        valueOfOutOfBoundsStates = newValueOfOutOfBoundsStates;
    }

    /**
     * Unsupported by this implementation.
     *
     * @param trueOrFalse ignored
     * @throws RuntimeException always
     */
    public void setAllowExpansionOfStateBounds(boolean trueOrFalse)
    {
        throw new RuntimeException("Not implemented.");
    }

    /**
     * Builds a map whose keys are the list's elements, each mapped to the empty
     * string. Not called from within this class.
     */
    private HashMap convertListToHashMap(List list)
    {
        HashMap result = new HashMap(list.size());
        Iterator i = list.iterator();
        while (i.hasNext())
        {
            result.put(i.next(), "");
        }
        return result;
    }

    /**
     * Always fails: a value function must be constructed with a world.
     *
     * @throws RuntimeException always
     */
    public ValueFunctionHashMap()
    {
        throw new RuntimeException("You should initialize the value function with the world so that it knows what the terminal states are.");
    }

    /**
     * @return an iterator over the stored states, backed by the live table
     */
    public Iterator getKeySetIterator()
    {
        return statesToValues.keySet().iterator();
    }

    /**
     * @return the live key set of the backing table (not a copy)
     */
    public Set<State> getKeySet()
    {
        return statesToValues.keySet();
    }

    /** Removes every stored state/value pair. */
    public void clear()
    {
        statesToValues.clear();
    }

    /**
     * @return the number of states currently stored
     */
    public int size()
    {
        return statesToValues.size();
    }

    /**
     * The value of a state is defined as the sum of the reinforcements received when starting
     * in that state and following some fixed policy to a terminal state;  the optimal policy would
     * map states to actions that maximizes the sum of reinforcements received when starting in an
     * arbitrary state and performing actions until the terminal state is reached.
     *
     * <p>Lookup order: win/lose/draw states of an {@link IsWinnable} world, then
     * terminal states, then the stored table.
     *
     * @param state the state to evaluate
     * @return the fixed game-outcome value, the terminal-state value, the stored
     *         value, or (non-safe mode only) the default for non-stored states
     * @throws NonLearnedStateException in safe mode when the state is not stored
     */
    public double getValue(State state)
    {
        if (world instanceof IsWinnable)
        {
            IsWinnable game = (IsWinnable) world;
            if (game.isWinState(state)) return 1;
            if (game.isLoseState(state)) return -1;
            // NOTE(review): a draw is valued the same as a win (1); confirm this is intended.
            if (game.isDrawState(state)) return 1;
        }

        if (world.isTerminalState(state)) return this.valueOfTerminalStates;

        Double value = statesToValues.get(state);

        if (value != null)
        {
            return value.doubleValue();
        }

        if (SAFE_MODE)
        {
            throw new NonLearnedStateException("\n\nQueried for non-stored state, don't have generalization ability.\nAre you testing on a state which was never trained? Check your test corpus.\nUnknown state was: " + state);
        }
        return valueOfNonStoredStates;
    }

    /**
     * @return the accumulated timer value
     */
    public double getTimer()
    {
        return this.timer;
    }

    /** Resets the accumulated timer to zero. */
    public void resetTimer()
    {
        this.timer = 0;
    }

    /**
     * Adds to the accumulated timer.
     *
     * @param time the amount to add
     */
    public void addTimer(double time)
    {
        this.timer += time;
    }

    /**
     * Stores (or overwrites) the value for a state.
     *
     * @param state    the state to update
     * @param newValue the value to associate with it
     */
    public void setValue(State state, double newValue)
    {
        statesToValues.put(state, Double.valueOf(newValue));
    }

    /**
     * Serializes the identification header followed by the backing table to a file.
     * Only the table is written; the world, name and other settings are not.
     *
     * @param filename path of the file to write
     * @throws Exception if the file cannot be opened or written
     */
    public void saveTo(String filename) throws Exception
    {
        FileOutputStream fileOut = new FileOutputStream(filename);
        try
        {
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(SAVED_IDENTIFICATION);
            out.writeObject(this.statesToValues);
            out.close(); // flushes buffered object data, then closes fileOut
        }
        finally
        {
            // Harmless repeat after a successful close; releases the handle on failure.
            fileOut.close();
        }
    }

    /**
     * Replaces the backing table with one read from a file written by
     * {@link #saveTo(String)}. The current table is left untouched if reading fails.
     *
     * <p>This uses Java native deserialization; only load files from trusted sources.
     *
     * @param filename path of the file to read
     * @throws RuntimeException if the file does not start with the expected header
     * @throws Exception        if the file cannot be opened or deserialized
     */
    public void loadFrom(String filename) throws Exception
    {
        FileInputStream fileIn = new FileInputStream(filename);
        try
        {
            ObjectInputStream in = new ObjectInputStream(fileIn);
            // Compare before casting so a non-String or null header yields the
            // descriptive error below rather than a ClassCastException/NPE.
            Object header = in.readObject();
            if (!SAVED_IDENTIFICATION.equals(header))
            {
                throw new RuntimeException("Error: ValueFunction.loadFrom read from a file that contained something other than a ValueFunction -- a "+header);
            }
            // Element types cannot be checked at runtime because of erasure;
            // saveTo() only ever writes this class's own table.
            @SuppressWarnings("unchecked")
            HashMap<State, Double> loaded = (HashMap<State, Double>) in.readObject();
            this.statesToValues = loaded;
        }
        finally
        {
            fileIn.close();
        }
    }

    /**
     * Renders every stored entry as one line of the form
     * {@code vf(<state>) <state hashCode> = <value>}, with the value formatted
     * as {@code #0.00}. Lines follow the table's iteration order.
     *
     * @return the rendered text, or the empty string when nothing is stored
     */
    public String toText()
    {
        java.text.DecimalFormat decimalFormat = new java.text.DecimalFormat("#0.00");
        StringBuilder result = new StringBuilder();
        for (Map.Entry<State, Double> entry : statesToValues.entrySet())
        {
            State s = entry.getKey();
            double value = entry.getValue().doubleValue();
            result.append("vf(").append(s).append(") ")
                  .append(s.hashCode())
                  .append(" = ")
                  .append(decimalFormat.format(value))
                  .append("\n");
        }

        return result.toString();
    }

    /**
     * Sets the value that {@link #getValue(State)} reports for terminal states.
     *
     * @param newValueOfTerminalStates the new terminal-state value
     */
    public void setValueOfTerminalStates(double newValueOfTerminalStates)
    {
        valueOfTerminalStates = newValueOfTerminalStates;
    }

}
