/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.rl.maze;

import java.text.DecimalFormat;
import java.util.List;
import net.pakl.rl.ActionSet;
import net.pakl.rl.Agent;
import net.pakl.rl.PolicyExtractor;
import net.pakl.rl.ReinforcementFunction;
import net.pakl.rl.ValueFunction;
import net.pakl.rl.ValueFunctionHashMap;
import net.pakl.rl.World;
import net.pakl.rl.maze.Action2D;
import net.pakl.rl.maze.MazeWorld;
import net.pakl.rl.maze.State2D;
import net.pakl.rl.maze.Toolbox;

/**
 * Minimal command-line driver that solves a randomly generated maze by value iteration.
 *
 * <p>It builds a {@code mazeX} x {@code mazeY} maze with obstacles, makes the far corner
 * {@code (mazeX - 1, mazeY - 1)} a terminal state, runs {@code trainingTrialsToRun} sweeps of
 * value iteration, and prints the resulting value function and extracted policy to stdout.
 */
public class MazeMainSimple {
    /** Maze width: number of x positions; the first {@link State2D} coordinate. */
    static int mazeX = 10;
    /** Maze height: number of y positions; the second {@link State2D} coordinate. */
    static int mazeY = 10;
    /** Discount factor handed to both the agent and the policy extractor. */
    static double discountFactor = 1.0;
    /** Number of value-iteration sweeps to run. */
    static int trainingTrialsToRun = 30;

    /**
     * Builds the maze, solves it by value iteration and prints the value function and policy.
     *
     * @param args ignored
     * @throws Exception propagated from the maze / reinforcement-learning library calls
     */
    public static void main(String[] args) throws Exception {
        MazeWorld world = buildWorld();
        System.out.println(world.toText());
        ReinforcementFunction rf = buildReinforcementFunction();
        Toolbox toolbox = new Toolbox();
        ActionSet policy = toolbox.makeSimpleMazePolicy(world);
        Agent agent = MazeMainSimple.createAgent(world, rf, policy, discountFactor);
        ValueFunctionHashMap valueFunction = runValueIteration(agent, world);
        MazeMainSimple.showValueFunction(valueFunction);
        PolicyExtractor policyExtractor = new PolicyExtractor();
        // NOTE(review): the meaning of the null third argument is defined by PolicyExtractor — confirm.
        String policyString = policyExtractor.extractOptimalPolicy(policy, valueFunction, null, world, rf, discountFactor);
        System.out.println(policyString);
        System.out.println("\nAll Done.");
    }

    /** Creates the obstacle maze and marks the goal corner as terminal. */
    private static MazeWorld buildWorld() throws Exception {
        MazeWorld world = new MazeWorld("maze");
        world.setLengths(mazeX, mazeY);
        world.setPObstacle(0.1);
        world.build();
        world.makeIntoTerminalState(goalState());
        return world;
    }

    /** Returns a new state for the goal corner {@code (mazeX - 1, mazeY - 1)}. */
    private static State2D goalState() {
        return new State2D(mazeX - 1, mazeY - 1);
    }

    /**
     * Builds the reward model: every transition yields -1.0 by default, except the goal state with
     * {@code Action2D(0, 0)}, which yields 0.0.
     */
    private static ReinforcementFunction buildReinforcementFunction() {
        ReinforcementFunction rf = new ReinforcementFunction();
        rf.setDefaultReinforcement(-1.0);
        // Presumably Action2D(0, 0) is the "stay put" action at the goal — confirm in Action2D.
        rf.setReward(goalState(), new Action2D(0, 0), 0.0);
        return rf;
    }

    /**
     * Runs {@code trainingTrialsToRun} value-iteration sweeps and returns the final value function.
     *
     * @param agent agent that performs each sweep
     * @param world world the value functions are defined over
     * @return the value function produced by the last sweep
     */
    private static ValueFunctionHashMap runValueIteration(Agent agent, MazeWorld world) throws Exception {
        ValueFunctionHashMap valueFunction = new ValueFunctionHashMap(world);
        ValueFunctionHashMap tempNewValueFunction = new ValueFunctionHashMap(world);
        for (int trial = 0; trial < trainingTrialsToRun; ++trial) {
            System.out.println("Iteration " + trial);
            // Assumed: fills the first map from the second (previous) one. The two buffers are then
            // swapped so the newest values become current without allocating another map.
            agent.performValueIteration(tempNewValueFunction, valueFunction);
            ValueFunctionHashMap swap = tempNewValueFunction;
            tempNewValueFunction = valueFunction;
            valueFunction = swap;
        }
        return valueFunction;
    }

    /**
     * Prints the value function as a grid: one line per y row, one column per x position.
     *
     * @param valueFunction value function to print
     */
    private static void showValueFunction(ValueFunction valueFunction) {
        System.out.println("vf");
        DecimalFormat decimalFormat = new DecimalFormat("00.00");
        // State2D is (x, y), so rows iterate over y (mazeY) and columns over x (mazeX).
        for (int y = 0; y < mazeY; ++y) {
            for (int x = 0; x < mazeX; ++x) {
                String s = decimalFormat.format(valueFunction.getValue(new State2D(x, y)));
                // Pad non-negative values so they line up with their "-" prefixed counterparts.
                if (s.length() < 6) {
                    s = " " + s;
                }
                System.out.print(s + " ");
            }
            // Trailing row marker, printed verbatim (println, not printf).
            System.out.println("   %vf ");
        }
    }

    /**
     * Creates an agent wired to the given world, reward model, action set and discount factor.
     *
     * @param world world the agent acts in
     * @param rf reward model
     * @param policy action set used as the agent's policy
     * @param discountFactor discount factor for future rewards
     * @return the configured agent
     */
    private static Agent createAgent(World world, ReinforcementFunction rf, ActionSet policy, double discountFactor) {
        Agent agent = new Agent("agent");
        agent.setWorld(world);
        agent.setReinforcementFunction(rf);
        agent.setPolicy(policy);
        agent.setDiscountFactor(discountFactor);
        return agent;
    }
}

