/*
 * SimpleRecurrentNetTest2.java
 *
 * Created on February 25, 2004, 2:17 PM
 */

package net.pakl.neuralnet;

/**
 * Command-line demo that trains a {@link SimpleRecurrentNet} on one short
 * sequence and then prints how closely the trained net reproduces it.
 *
 * <p>The training sequence presents the same input vector at every step while
 * the target vector changes from step to step, so the input alone does not
 * identify which target is expected.
 *
 * @author  patryk
 */
public class SimpleRecurrentNetTest2
{
    /**
     * Exclusive upper bound of the unit indices read back as network output.
     * The same value is passed as the second constructor argument of the net
     * and used as the last layer break.
     * NOTE(review): presumably the total number of units in the net - confirm
     * against the SimpleRecurrentNet constructor.
     */
    private static final int UNIT_COUNT = 8;

    /** Index of the first unit whose activity is read back as network output. */
    private static final int FIRST_OUTPUT_UNIT = 5;

    /** Number of times the whole sequence is presented during training. */
    private static final int TRAINING_TRIALS = 55000;

    /** A progress line is printed whenever the trial index is a multiple of this. */
    private static final int REPORT_INTERVAL = 1000;

    public SimpleRecurrentNetTest2()
    {
    }

    /**
     * Builds the net, trains it on the fixed sequence and prints the test report.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        // Identical input at every step ...
        double[][] inputSeq1 = {
            {1, 0, 0},
            {1, 0, 0},
            {1, 0, 0}
        };
        // ... but a different target at each step.
        double[][] outputSeq1 = {
            {1, 0, 0},
            {0, 1, 0},
            {0, 0, 1}
        };

        double learningRate = 0.05;
        double momentumTerm = 0.0;
        // Same values as the original literal {3, 5, 8}; the last two entries
        // bound the unit range that getOutput/showOutput read.
        int[] layerBreaks = {3, FIRST_OUTPUT_UNIT, UNIT_COUNT};

        // NOTE(review): the first argument looks like a random seed - confirm
        // against the SimpleRecurrentNet constructor.
        SimpleRecurrentNet net =
            new SimpleRecurrentNet(123456, UNIT_COUNT, learningRate, momentumTerm, layerBreaks);

        trainNetwork(net, inputSeq1, outputSeq1, TRAINING_TRIALS);
        testNetwork(net, inputSeq1, outputSeq1);
    }

    /**
     * Presents the whole sequence to the net {@code trials} times, calling
     * backpropogate after every step.
     *
     * <p>Context units are cleared before each pass over the sequence, and the
     * hidden excitation is copied into them after every step.
     *
     * @param net              network to train
     * @param inputList        one input vector per sequence step
     * @param targetOutputList one target vector per sequence step, indexed in
     *                         parallel with {@code inputList}
     * @param trials           number of passes over the sequence
     */
    private static void trainNetwork(SimpleRecurrentNet net, double[][] inputList, double[][] targetOutputList, int trials)
    {
        for (int trial = 0; trial < trials; trial++)
        {
            net.clearContextUnits();
            for (int step = 0; step < inputList.length; step++)
            {
                net.feedforward(inputList[step]);
                net.backpropogate(targetOutputList[step]);
                net.copyHiddenExcitationToContextUnits();
            }
            if (trial % REPORT_INTERVAL == 0)
            {
                System.out.println("Training trial " + trial + " SumSquared Error: " + net.getSumSquaredError());
            }
        }
    }

    /**
     * Runs the sequence through the net once without training and prints, for
     * each step, the expected vector, the net's output activities and the
     * normalized dot product of the two.
     *
     * @param net                network to evaluate
     * @param inputList          one input vector per sequence step
     * @param expectedOutputList one expected vector per sequence step, indexed
     *                           in parallel with {@code inputList}
     */
    private static void testNetwork(SimpleRecurrentNet net, double[][] inputList, double[][] expectedOutputList)
    {
        net.clearContextUnits();
        for (int step = 0; step < inputList.length; step++)
        {
            net.feedforward(inputList[step]);

            System.out.print("goal : ");
            showVector(expectedOutputList[step]);

            System.out.print("net  : ");
            showOutput(net);

            double similarity = VectorComparer.normalizedDotProduct(expectedOutputList[step], getOutput(net));
            System.out.println("\nNormalized Dot = " + similarity);

            net.copyHiddenExcitationToContextUnits();
        }
    }

    /**
     * Collects the activities of the units in the range
     * [{@code FIRST_OUTPUT_UNIT}, {@code UNIT_COUNT}).
     *
     * @param net network to read from
     * @return a new array holding one activity per unit in that range
     */
    private static double[] getOutput(SimpleRecurrentNet net)
    {
        double[] result = new double[UNIT_COUNT - FIRST_OUTPUT_UNIT];
        for (int unit = FIRST_OUTPUT_UNIT; unit < UNIT_COUNT; unit++)
        {
            result[unit - FIRST_OUTPUT_UNIT] = net.getActivity(unit);
        }
        return result;
    }

    /**
     * Prints every element of {@code a} followed by a single space, without a
     * trailing newline.
     *
     * @param a vector to print
     */
    private static void showVector(double[] a)
    {
        for (int i = 0; i < a.length; i++)
        {
            System.out.print(a[i] + " ");
        }
    }

    /**
     * Prints the activity of each unit in the range
     * [{@code FIRST_OUTPUT_UNIT}, {@code UNIT_COUNT}), each followed by a single
     * space, without a trailing newline.
     *
     * @param net network to read from
     */
    private static void showOutput(SimpleRecurrentNet net)
    {
        // Reads getActivity directly (rather than going through a double[]) so
        // the printed text is exactly what string concatenation of the
        // returned value produces.
        for (int unit = FIRST_OUTPUT_UNIT; unit < UNIT_COUNT; unit++)
        {
            System.out.print(net.getActivity(unit) + " ");
        }
    }
}
