package net.pakl.neuralnet;
import java.io.*;
import java.util.*;

/**
 * Trains a {@link SimpleRecurrentNet} on numeric CSV data and prints its predictions.
 *
 * <p>Input patterns are read from {@code sources.csv} and target patterns from
 * {@code targets.csv} (both relative to the working directory); row {@code k} of one
 * file is paired with row {@code k} of the other. Both data sets are min-max rescaled
 * column by column before training.
 */
public class SRNTrainer
{
    public static void main(String[] args) throws Exception
    {
        new SRNTrainer().run();
    }

    /** Sentinel value in targets.csv marking a missing target value. */
    private static final double MISSING_TARGET = -100;

    int inputLayerSize = 4;             // size of the first (input) layer of the network
    int hiddenLayerSize = 10;           // number of hidden units (also, size of context feedback layer)
    int outputLayerSize = 2;            // size of the last (output) layer of the network
    double learningRate = 0.1;          // rate at which weights change according to gradient
    double momentumTerm = 0;            // continued change in direction of old gradient
    long seed = 123;                    // random seed for initialization of weights
    int NUM_TRAINING_TRIALS = 80000;    // training iterations through data sequence
    int contextResetInterval = 100;     // context units are cleared after patterns whose index is a multiple of this

    /**
     * Reads in the data, rescales it, trains a simple recurrent network on it and
     * prints the network's output next to each target.
     *
     * <p>Training and testing both use every row of the data set; there is no held-out
     * split. Context units are cleared at the start of every training trial, at the
     * start of testing, and after every pattern whose index is a multiple of
     * {@link #contextResetInterval}.
     *
     * @throws Exception if a data file cannot be read or parsed, if sources.csv has no
     *         rows, or if the two files have different numbers of rows
     */
    public void run() throws Exception
    {
        int totalUnits = inputLayerSize + hiddenLayerSize + outputLayerSize;
        int [] layerBreaks = {inputLayerSize, inputLayerSize + hiddenLayerSize, totalUnits};

        SimpleRecurrentNet net = new SimpleRecurrentNet(seed,
                totalUnits,
                learningRate, momentumTerm, layerBreaks);

        System.out.println("hiddenLayerSize = " + hiddenLayerSize);
        System.out.println("learningRate = " + learningRate);
        System.out.println("momentumTerm = " + momentumTerm);
        System.out.println("NUM_TRAINING_TRIALS = " + NUM_TRAINING_TRIALS);
        System.out.println("seed = " + seed);
        List<double[]> sources = readCSV("sources.csv");
        List<double[]> targets = readCSV("targets.csv");
        System.out.println("There are " + sources.size() + " patterns.");

        // Fail fast with a clear message instead of an index error deep inside training.
        if (sources.isEmpty())
        {
            throw new IllegalStateException("sources.csv contains no data rows");
        }
        if (targets.size() != sources.size())
        {
            throw new IllegalStateException("sources.csv has " + sources.size()
                    + " rows but targets.csv has " + targets.size());
        }

        System.out.println("RESCALING");
        rescale(sources, getMaxes(sources), getMins(sources));
        // Missing targets must be filled before their column's min/max is computed.
        fillMissingTargets(targets);
        rescale(targets, getMaxes(targets), getMins(targets));

        System.out.println("TRAINING");
        train(net, sources, targets);

        System.out.println("TESTING");
        testAndPrint(net, sources, targets);
    }

    /**
     * Replaces the {@link #MISSING_TARGET} sentinel in the first {@link #outputLayerSize}
     * target columns with +1 if the most recent non-missing value in that column was
     * positive, and with -1 otherwise (including when no value has been seen yet).
     */
    private void fillMissingTargets(List<double[]> targets)
    {
        double [] previous = new double[outputLayerSize];   // last non-missing value per column
        for (double [] row : targets)
        {
            for (int c = 0; c < previous.length; c++)
            {
                if (row[c] == MISSING_TARGET)
                {
                    row[c] = (previous[c] > 0) ? 1 : -1;
                }
                else
                {
                    previous[c] = row[c];
                }
            }
        }
    }

    /**
     * Runs up to NUM_TRAINING_TRIALS passes over the data and prints the elapsed time
     * and summed squared error of each pass. It stops early if a file named
     * {@code endrun.msg} exists in the working directory.
     */
    private void train(SimpleRecurrentNet net, List<double[]> sources, List<double[]> targets)
            throws Exception
    {
        for (int trial = 0; trial < NUM_TRAINING_TRIALS; trial++)
        {
            net.clearSSEMeasure();
            net.clearContextUnits();
            long start = System.currentTimeMillis();
            for (int p = 0; p < sources.size(); p++)
            {
                net.feedforward(sources.get(p));
                net.backpropogate(targets.get(p));
                net.copyHiddenExcitationToContextUnits();
                resetContextIfDue(net, p);
            }
            System.out.println((System.currentTimeMillis() - start) + "ms " + trial + " " + net.getSumSquaredError());

            // Creating this file lets a long run stop gracefully and still reach testing.
            if (new File("endrun.msg").exists()) break;
        }
    }

    /** Feeds every pattern forward once and prints "target output" for each row. */
    private void testAndPrint(SimpleRecurrentNet net, List<double[]> sources, List<double[]> targets)
            throws Exception
    {
        net.clearContextUnits();
        for (int p = 0; p < sources.size(); p++)
        {
            net.feedforward(sources.get(p));
            System.out.println(
                    vectorToString(targets.get(p))
                    + " "
                    + vectorToString(net.getOutputLayerActivity())
                    );
            net.copyHiddenExcitationToContextUnits();
            resetContextIfDue(net, p);
        }
    }

    /**
     * Clears the context units after pattern {@code p} when {@code p} is a multiple of
     * {@link #contextResetInterval}. A non-positive interval disables the reset.
     *
     * <p>NOTE(review): index 0 counts as a multiple, so the first sequence is only one
     * pattern long. If sequences of exactly {@code contextResetInterval} patterns are
     * intended, the test should be {@code (p + 1) % contextResetInterval == 0}. Confirm
     * this against the data layout.
     */
    private void resetContextIfDue(SimpleRecurrentNet net, int p) throws Exception
    {
        if (contextResetInterval > 0 && p % contextResetInterval == 0) net.clearContextUnits();
    }

    // --------------------------------------------------------------------------------
    // DATA INPUT / OUTPUT AND PREPROCESSING HELPERS
    // --------------------------------------------------------------------------------

    /**
     * Reads a comma-separated file of numbers into a list of rows.
     *
     * <p>Blank lines are skipped. Each field is trimmed and parsed with
     * {@link Double#parseDouble}, which always uses '.' as the decimal separator,
     * independent of the default locale.
     *
     * @param filename path of the CSV file to read
     * @return one array per non-blank line, in file order
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IllegalArgumentException if a field is not a number; the message gives
     *         the file name, line and field
     */
    public List<double[]> readCSV(String filename) throws Exception
    {
        List<double[]> result = new ArrayList<double[]>();
        Scanner in = new Scanner(new File(filename));
        try
        {
            int lineNumber = 0;
            while (in.hasNextLine())
            {
                String line = in.nextLine();
                lineNumber++;
                if (line.trim().length() == 0) continue;   // tolerate blank or trailing lines

                String [] fields = line.split(",");
                double [] row = new double[fields.length];
                for (int i = 0; i < fields.length; i++)
                {
                    String field = fields[i].trim();
                    try
                    {
                        row[i] = Double.parseDouble(field);
                    }
                    catch (NumberFormatException e)
                    {
                        throw new IllegalArgumentException(filename + " line " + lineNumber
                                + ", field " + (i + 1) + ": not a number: \"" + field + "\"", e);
                    }
                }
                result.add(row);
            }
        }
        finally
        {
            in.close();
        }
        return result;
    }

    /** Formats the values of {@code a}, each followed by a single space. */
    private String vectorToString(double [] a)
    {
        StringBuilder sb = new StringBuilder();
        for (double value : a)
        {
            sb.append(value).append(' ');
        }
        return sb.toString();
    }

    /** Prints {@code a} to standard output in the same format as vectorToString. */
    private void showVector(double [] a)
    {
        System.out.print(vectorToString(a));
    }

    /**
     * Returns the per-column maximum of {@code data}. The number of columns is taken
     * from the first row, and {@code data} must not be empty.
     */
    private double[] getMaxes(List<double[]> data)
    {
        double [] result = new double[data.get(0).length];
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
        // double and would be reported as the max of an all-negative column.
        Arrays.fill(result, Double.NEGATIVE_INFINITY);
        for (double [] row : data)
        {
            for (int i = 0; i < result.length; i++)
            {
                if (row[i] > result[i]) result[i] = row[i];
            }
        }
        return result;
    }

    /**
     * Returns the per-column minimum of {@code data}. The number of columns is taken
     * from the first row, and {@code data} must not be empty.
     */
    private double[] getMins(List<double[]> data)
    {
        double [] result = new double[data.get(0).length];
        Arrays.fill(result, Double.POSITIVE_INFINITY);
        for (double [] row : data)
        {
            for (int i = 0; i < result.length; i++)
            {
                if (row[i] < result[i]) result[i] = row[i];
            }
        }
        return result;
    }

    /**
     * Rescales every value in place to {@code (x - min) / (max - min)}, which maps each
     * column onto [0, 1]. Columns whose max equals their min are set to 0 instead of
     * producing NaN.
     */
    private void rescale(List<double[]> data, double[] maxes, double[] mins)
    {
        for (double [] row : data)
        {
            for (int i = 0; i < row.length; i++)
            {
                double range = maxes[i] - mins[i];
                row[i] = (range == 0) ? 0 : (row[i] - mins[i]) / range;
            }
        }
    }

}
