/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.neuralnet;

import net.pakl.neuralnet.SimpleRecurrentNet;
import net.pakl.neuralnet.VectorComparer;

/**
 * Manual smoke test for {@link SimpleRecurrentNet}.
 *
 * <p>Trains a network on one three-step sequence. The same input vector
 * {@code (1, 0, 0)} is presented at every step, while the target moves through the
 * one-hot vectors {@code (1, 0, 0)}, {@code (0, 1, 0)}, {@code (0, 0, 1)}. After
 * training, the sequence is replayed without learning, and each step's goal and actual
 * output are printed together with their normalized dot product.
 *
 * <p>Because every input is identical, distinct outputs can only come from state
 * carried between steps. Presumably that state is the context units filled by
 * {@code copyHiddenExcitationToContextUnits()} — confirm against SimpleRecurrentNet.
 */
public class SimpleRecurrentNetTest2 {

    /** Seed passed to the network constructor (presumably for weight init — TODO confirm). */
    private static final long RANDOM_SEED = 123456L;

    /**
     * Second constructor argument of the network. NOTE(review): presumably the total unit
     * count; it equals the last entry of {@link #LAYER_BREAKS}.
     */
    private static final int UNIT_COUNT = 8;

    /** Layer boundaries handed to the network constructor (a fresh copy is passed). */
    private static final int[] LAYER_BREAKS = {3, 5, 8};

    /**
     * First unit index read as network output. It is derived from {@link #LAYER_BREAKS}
     * rather than hard-coded, so the readout stays in sync with the configured layers.
     * This assumes the final layer (between the last two breaks) is the output layer,
     * which matches the previously hard-coded range 5..8.
     */
    private static final int OUTPUT_START = LAYER_BREAKS[LAYER_BREAKS.length - 2];

    /** One past the last unit index read as network output. */
    private static final int OUTPUT_END = LAYER_BREAKS[LAYER_BREAKS.length - 1];

    /** Number of training passes over the sequence performed by {@link #main}. */
    private static final int TRAINING_TRIALS = 55000;

    /** Training progress is printed on every trial index divisible by this value. */
    private static final int REPORT_INTERVAL = 1000;

    /**
     * Builds the network, trains it on the counter sequence and prints the test results.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[][] inputSeq1 = {{1.0, 0.0, 0.0}, {1.0, 0.0, 0.0}, {1.0, 0.0, 0.0}};
        double[][] outputSeq1 = {{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}};
        double learningRate = 0.05;
        double momentumTerm = 0.0;
        // Pass a copy so the network can never alter the shared layer configuration.
        SimpleRecurrentNet net = new SimpleRecurrentNet(
                RANDOM_SEED, UNIT_COUNT, learningRate, momentumTerm, LAYER_BREAKS.clone());
        trainNetwork(net, inputSeq1, outputSeq1, TRAINING_TRIALS);
        testNetwork(net, inputSeq1, outputSeq1);
    }

    /**
     * Trains {@code net} on one input/target sequence for a number of trials.
     *
     * <p>Each trial clears the context units, then, for every step, runs a forward pass,
     * a backpropagation against that step's target, and copies the hidden excitation into
     * the context units for the next step. On every {@link #REPORT_INTERVAL}-th trial the
     * value of {@code net.getSumSquaredError()} is printed.
     *
     * @param net              network to train (mutated)
     * @param inputList        input vector for each step of the sequence
     * @param targetOutputList target output vector for each step; one per input step
     * @param trials           number of passes over the sequence
     * @throws IllegalArgumentException if the two sequences differ in length
     */
    private static void trainNetwork(SimpleRecurrentNet net, double[][] inputList, double[][] targetOutputList, int trials) {
        requireSameLength(inputList, targetOutputList);
        for (int trial = 0; trial < trials; ++trial) {
            // Every trial replays the sequence from a cleared context.
            net.clearContextUnits();
            for (int step = 0; step < inputList.length; ++step) {
                net.feedforward(inputList[step]);
                net.backpropogate(targetOutputList[step]);
                // Make this step's hidden excitation available as context for the next step.
                net.copyHiddenExcitationToContextUnits();
            }
            if (trial % REPORT_INTERVAL == 0) {
                System.out.println("Training trial " + trial + " SumSquared Error: " + net.getSumSquaredError());
            }
        }
    }

    /**
     * Replays a sequence through {@code net} without learning and prints, for each step,
     * the expected vector, the network's output vector and their normalized dot product.
     *
     * @param net                trained network (its context units are modified)
     * @param inputList          input vector for each step of the sequence
     * @param expectedOutputList expected output vector for each step; one per input step
     * @throws IllegalArgumentException if the two sequences differ in length
     */
    private static void testNetwork(SimpleRecurrentNet net, double[][] inputList, double[][] expectedOutputList) {
        requireSameLength(inputList, expectedOutputList);
        net.clearContextUnits();
        for (int step = 0; step < inputList.length; ++step) {
            net.feedforward(inputList[step]);
            // Read the output activities once; they are used both for display and for the score.
            double[] actual = getOutput(net);
            System.out.print("goal : ");
            showVector(expectedOutputList[step]);
            System.out.print("net  : ");
            showVector(actual);
            System.out.println("\nNormalized Dot = " + VectorComparer.normalizedDotProduct(expectedOutputList[step], actual));
            net.copyHiddenExcitationToContextUnits();
        }
    }

    /**
     * Rejects sequences that do not pair exactly one target with each input step.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    private static void requireSameLength(double[][] inputs, double[][] targets) {
        if (inputs.length != targets.length) {
            throw new IllegalArgumentException("input sequence has " + inputs.length
                    + " steps but target sequence has " + targets.length);
        }
    }

    /**
     * Copies the activities of units {@code [OUTPUT_START, OUTPUT_END)} out of the network.
     *
     * @param net network to read from
     * @return a new array holding one activity per output unit, in unit order
     */
    private static double[] getOutput(SimpleRecurrentNet net) {
        double[] result = new double[OUTPUT_END - OUTPUT_START];
        for (int unit = OUTPUT_START; unit < OUTPUT_END; ++unit) {
            result[unit - OUTPUT_START] = net.getActivity(unit);
        }
        return result;
    }

    /**
     * Prints each element of {@code a} followed by a single space, without a newline.
     *
     * @param a values to print
     */
    private static void showVector(double[] a) {
        for (double value : a) {
            System.out.print(value + " ");
        }
    }
}

