package net.pakl.rl;

import java.util.*;

//import edu.oswego.cs.dl.util.concurrent.*;
//import edu.emory.mathcs.backport.java.util.concurrent.*;
import java.util.concurrent.*;
/**
 * This class represents a multi-threaded agent (and is therefore faster if
 * you have multiple CPUs) appropriate for reinforcement learning with full
 * value iteration, since during value iteration each state can be updated
 * independently of the others (as long as testTrainSameVf is false).
 */
public class AgentParallelized extends Agent
{
    /** Maximum number of states handed to a single pool task. */
    private static final int STATES_PER_TASK = 100000;

    /** Number of pool threads, and also the capacity of the pool's job queue. */
    protected int threadsAtATime = 2;

    /**
     * Counter used only to label tasks in the progress output. It is static, so
     * every access goes through the static synchronized helpers below.
     */
    static int uniqueThreadNumber = 0;

    /**
     * Creates a named agent whose value-iteration sweeps run on a thread pool.
     *
     * @param newName    name of the agent
     * @param numThreads number of pool threads to use for each sweep
     */
    public AgentParallelized(String newName, int numThreads)
    {
        super(newName);
        this.name = newName;
        this.threadsAtATime = numThreads;
        System.err.println("! Agent " + getName() + " has been created.");
    }

    /**
     * Creates an agent with the default name.
     *
     * @param numThreads number of pool threads to use for each sweep
     */
    public AgentParallelized(int numThreads)
    {
        // The delegated constructor already stores numThreads.
        this("ParalellizedAgent", numThreads);
    }

    /**
     * Returns the next task label. Static and synchronized so that the lock
     * (the class) matches the static field it protects.
     */
    private static synchronized int getUniqueThreadNumber()
    {
        uniqueThreadNumber++;
        return uniqueThreadNumber;
    }

    /** Resets the task label counter at the start of a sweep. */
    private static synchronized void resetUniqueThreadNumber()
    {
        uniqueThreadNumber = 0;
    }

    /**
     * Performs one sweep of value iteration over every state key of
     * <code>valueFunction</code>, distributing the per-state updates over a
     * thread pool of <code>threadsAtATime</code> threads.
     *
     * The state iterator is consumed only by the calling thread, which cuts the
     * states into batches of at most {@link #STATES_PER_TASK} and submits each
     * batch as one task. Iterators are generally not safe for concurrent use,
     * so the pool threads never touch it.
     *
     * NOTE(review): performValueIterationUpdateOnState is invoked concurrently
     * from several threads; it presumably accumulates into totalDelta and
     * maximumDelta -- confirm that those updates are thread-safe.
     *
     * @param newValueFunction value function that receives the updated values
     * @param valueFunction    value function the updates are computed from; must
     *                         be a ValueFunctionHashMap
     * @return <code>newValueFunction</code>
     * @throws RuntimeException if the pool has not terminated by the time the
     *                          wait ends (timeout or interruption)
     */
    public ValueFunction performValueIteration(final ValueFunction newValueFunction, final ValueFunction valueFunction)
    {
        final Iterator<?> stateIterator = ((ValueFunctionHashMap) valueFunction).getKeySetIterator();
        totalDelta = 0;
        resetUniqueThreadNumber();

        // Bounded queue + CallerRunsPolicy: when all workers are busy and the
        // queue is full, the submitting thread runs the batch itself.
        BlockingQueue<Runnable> jobQueue = new ArrayBlockingQueue<Runnable>(threadsAtATime);
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                threadsAtATime, threadsAtATime, 100L, TimeUnit.MILLISECONDS,
                jobQueue, new ThreadPoolExecutor.CallerRunsPolicy());

        while (stateIterator.hasNext())
        {
            // Only this thread advances the iterator.
            final List<State> batch = new ArrayList<State>();
            while (batch.size() < STATES_PER_TASK && stateIterator.hasNext())
            {
                State state = (State) stateIterator.next();
                if (state != null)
                {
                    batch.add(state);
                }
            }
            if (batch.isEmpty())
            {
                continue;
            }

            pool.execute(new Runnable()
            {
                public void run()
                {
                    int myNumber = getUniqueThreadNumber();
                    long startTime = System.currentTimeMillis();
                    System.out.println(myNumber + " start");
                    for (State state : batch)
                    {
                        performValueIterationUpdateOnState(newValueFunction, valueFunction, state);
                    }
                    System.out.println(myNumber + " done, duration=" + (System.currentTimeMillis() - startTime) + "ms");
                }
            });
        }

        pool.shutdown();
        try
        {
            pool.awaitTermination(10000L, TimeUnit.SECONDS);
        }
        catch (InterruptedException e)
        {
            // Preserve the interrupt for callers; the check below reports the failure.
            Thread.currentThread().interrupt();
        }
        if (!pool.isTerminated())
        {
            // Do not leave workers mutating the value functions after we give up.
            pool.shutdownNow();
            throw new RuntimeException("There were still threads running and we got here.");
        }

        averageDelta = totalDelta / this.world.getNumberOfStates();
        System.out.println("avgDelta = " + averageDelta + " maxDelta was " + maximumDelta + " totalDelta was " + totalDelta);

        return newValueFunction;
    }

}
