package net.pakl.rl.maze;

import net.pakl.rl.*;
import java.util.*;
import java.io.*;

/**
 * Simple example that navigates through a grid world, for tutorial/intro
 * purposes; it does not load anything from a file but is instead fully
 * self-contained.
 *
 * <p>The program builds a maze, runs a fixed number of value-iteration
 * sweeps, prints the resulting value function and finally prints the policy
 * extracted from it.
 */
public class MazeMainSimple
{
    /** Maze extent along X and Y; the terminal state is placed at (mazeX-1, mazeY-1). */
    static int mazeX = 10, mazeY = 10;

    /** Discount factor handed to the agent and to the policy extractor. */
    static double discountFactor = 1.0;

    /** Number of value-iteration sweeps performed by {@link #main}. */
    static int trainingTrialsToRun = 30;

    /**
     * Builds the maze, trains the agent and prints the results to standard output.
     *
     * @param args command-line arguments (unused)
     * @throws Exception propagated unchanged from the underlying library calls
     */
    public static void main(String[] args) throws Exception
    {
        // -----------------------------------------------------------------------------------------------------
        // Set up a maze world, which contains:
        // - a list of states contained in the world, that should be visited.
        // - the transition function describing how state1+action = state2
        // -----------------------------------------------------------------------------------------------------
        MazeWorld world = new MazeWorld("maze");
        world.setLengths(mazeX, mazeY);
        world.setPObstacle(0.1);
        world.build();
        world.makeIntoTerminalState(new State2D(mazeX - 1, mazeY - 1));
        System.out.println(world.toText());

        // -----------------------------------------------------------------------------------------------------
        // Reinforcement Function: -1.0 by default, 0 for the terminal state paired with Action2D(0,0).
        // -----------------------------------------------------------------------------------------------------
        ReinforcementFunction rf = new ReinforcementFunction();
        rf.setDefaultReinforcement(-1.0);
        rf.setReward(new State2D(mazeX - 1, mazeY - 1), new Action2D(0, 0), 0);

        // -----------------------------------------------------------------------------------------------------
        // A policy specifies what actions are allowed from what states, e.g. don't move off edge of Maze world.
        // -----------------------------------------------------------------------------------------------------
        Toolbox toolbox = new Toolbox();
        ActionSet policy = toolbox.makeSimpleMazePolicy(world);

        // -----------------------------------------------------------------------------------------------------
        // The agent contains/executes the Reinforcement Learning algorithm.
        // -----------------------------------------------------------------------------------------------------
        Agent agent = createAgent(world, rf, policy, discountFactor);

        ValueFunction valueFunction = new ValueFunctionHashMap(world);
        ValueFunction tempNewValueFunction = new ValueFunctionHashMap(world);

        // -----------------------------------------------------------------------------------------------------
        // Training
        // -----------------------------------------------------------------------------------------------------
        for (int trial = 0; trial < trainingTrialsToRun; trial++)
        {
            System.out.println("Iteration " + trial);
            // NOTE(review): presumably reads the current estimates from the second argument and
            // writes the updated ones into the first -- confirm against Agent.performValueIteration.
            agent.performValueIteration(tempNewValueFunction, valueFunction);

            // Swap the two buffers so that valueFunction always refers to the newest result.
            ValueFunction swap = tempNewValueFunction;
            tempNewValueFunction = valueFunction;
            valueFunction = swap;
        }
        showValueFunction(valueFunction);

        // -----------------------------------------------------------------------------------------------------
        // Now, let's see the actions that lead to the terminal state with most reward.
        // (After extraction, policyExtractor.getOptimalActions() is also available as a List.)
        // -----------------------------------------------------------------------------------------------------
        PolicyExtractor policyExtractor = new PolicyExtractor();
        String policyString = policyExtractor.extractOptimalPolicy(policy, valueFunction, null, world, rf, discountFactor);
        System.out.println(policyString);
        System.out.println("\n" + "All Done.");
    }

    /**
     * Prints the value of every maze cell, one text line per Y row.
     *
     * <p>Each value is formatted with the pattern {@code 00.00} and padded with
     * a leading space when shorter than six characters, so that negative and
     * non-negative numbers line up. Every row is terminated by the marker
     * {@code "   %vf "}.
     *
     * @param valueFunction the value function to display
     */
    private static void showValueFunction(ValueFunction valueFunction)
    {
        System.out.println("vf");
        java.text.DecimalFormat decimalFormat = new java.text.DecimalFormat("00.00");
        // State2D takes (x, y), so x must be bounded by mazeX and y by mazeY;
        // this matters as soon as the maze is not square.
        for (int y = 0; y < mazeY; y++)
        {
            for (int x = 0; x < mazeX; x++)
            {
                String s = decimalFormat.format(valueFunction.getValue(new State2D(x, y)));
                if (s.length() < 6)
                {
                    s = " " + s;
                }
                System.out.print(s + " ");
            }
            System.out.println("   %vf ");
        }
    }

    /**
     * Creates an agent named {@code "agent"} and wires it to its collaborators.
     *
     * @param world          the world the agent is placed into
     * @param rf             the reinforcement function the agent uses
     * @param policy         the set of actions handed to the agent as its policy
     * @param discountFactor the discount factor set on the agent
     * @return the configured agent
     */
    private static Agent createAgent(World world, ReinforcementFunction rf, ActionSet policy, double discountFactor)
    {
        Agent agent = new Agent("agent");
        agent.setWorld(world);              // Place the agent into the world.
        agent.setReinforcementFunction(rf);
        agent.setPolicy(policy);
        agent.setDiscountFactor(discountFactor);
        return agent;
    }
}
