package net.pakl.neuralnet;
import java.io.Serializable;

/**
 * A simple recurrent (Elman-style) network trained with backpropagation.
 *
 * <p>After {@link #init} the neuron index space is laid out as:
 * {@code [bias neurons][context neurons][input neurons][hidden layer(s)...][output layer]}.
 * Layer 0 spans the bias, context and input neurons. Each layer is fully connected to the
 * next one. Bias neurons (if any) connect to every non-input layer. After each call to
 * {@link #backpropogate}, the activity of the first hidden layer is copied into the context
 * neurons, which feed it back into the first hidden layer on the next {@link #feedforward}.
 */
public class SimpleRecurrentNet implements Serializable
{
    java.util.Random randomNumberGenerator;

    // Total neuron count, including bias and context neurons (set in init()).
    int numNeurons;
    double learningRate;
    double momentumTerm;
    double[] activity;
    double[] netInput;

    // weight[pre][post] is the weight of the connection from neuron pre to neuron post.
    protected double[][] weight;
    // connected[pre][post] is true when a trainable connection pre -> post exists.
    boolean[][] connected;
    // Accumulated weight changes, applied by batchUpdateWeights() when batch mode is on.
    double[][] totalDeltaWij;
    // When true, backpropogate() accumulates weight changes instead of applying them.
    boolean enableBatchUpdate = false;  // default no batch

    // Initial weights are drawn uniformly from [-factor/2, factor/2).
    double initialRandomWeightFactor = 0.01d;

    // Exclusive end index of each layer. For example, a 3-2-3 network would have breaks at
    // 3, 5, 8. init() shifts them by NUM_BIAS_NEURONS + numContextNeurons so they index
    // directly into activity[].
    int[] layerBreaks;

    // Number of bias neurons at the start of the index space. It is currently 0, so the
    // network has no bias. A positive value would give hidden and output neurons a trainable
    // threshold.
    public static final int NUM_BIAS_NEURONS = 0;

    // Constant activity of every bias neuron. Their outgoing weights are still trained.
    public static final double BIAS_ACTIVATION = 0;

    public static final long serialVersionUID = 2564253280646345832L;

    // Number of EXTRA neurons in the input layer that hold copied context activity. init()
    // sets it to the number of neurons in the first hidden layer.
    int numContextNeurons;

    // Weight changes applied by the previous update; used for the momentum term.
    double[][] previousDeltaWij;

    /** Creates an uninitialised network; {@link #init} must be called before use. */
    public SimpleRecurrentNet()
    {}

    /**
     * Creates and initialises a network with a custom initial-weight scale.
     *
     * @param newInitialRandomWeightFactor scale of the initial random weights
     * @see #init(long, int, double, double, int[])
     */
    public SimpleRecurrentNet(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks, double newInitialRandomWeightFactor)
    {
        this.initialRandomWeightFactor = newInitialRandomWeightFactor;
        init(randomSeed, numNeurons, learningRate, momentumTerm, layerBreaks);
    }

    /**
     * Creates and initialises a network with the default initial-weight scale.
     *
     * @see #init(long, int, double, double, int[])
     */
    public SimpleRecurrentNet(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks)
    {
        init(randomSeed, numNeurons, learningRate, momentumTerm, layerBreaks);
    }

    /**
     * Allocates all state, adds context neurons and wires up the layers with small random
     * weights.
     *
     * @param randomSeed seed for the weight initialisation
     * @param numNeurons number of input + hidden + output neurons (excluding bias/context)
     * @param learningRate step size for weight updates
     * @param momentumTerm fraction of the previous weight change added to each update
     * @param layerBreaks exclusive end index of each layer, e.g. {3, 5, 8} for 3-2-3;
     *     must have at least two entries (input layer and first hidden layer)
     * @throws IllegalArgumentException if layerBreaks is null or has fewer than two entries
     */
    public void init(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks)
    {
        if (layerBreaks == null || layerBreaks.length < 2)
        {
            throw new IllegalArgumentException(
                    "layerBreaks must contain at least two entries (input layer and first hidden layer)");
        }
        randomNumberGenerator = new java.util.Random(randomSeed);

        System.out.println("SimpleRecurrentNet 3.1 created with random seed " + randomSeed + ". First random number is: ");

        // This draw consumes one value from the generator before the weights are initialised.
        // Keep it so that existing seeds reproduce the same networks.
        System.out.println(randomNumberGenerator.nextDouble());
        numContextNeurons = layerBreaks[1] - layerBreaks[0];

        System.out.println("numContextNeurons = " + numContextNeurons);
        int totalNeurons = numNeurons + NUM_BIAS_NEURONS + numContextNeurons;
        netInput = new double[totalNeurons];
        activity = new double[totalNeurons];
        weight = new double[totalNeurons][totalNeurons];
        connected = new boolean[totalNeurons][totalNeurons];
        previousDeltaWij = new double[totalNeurons][totalNeurons];

        System.out.println("Layerbreaks initially are:" + vectorToText(layerBreaks));
        this.layerBreaks = incrementAllBy(layerBreaks, NUM_BIAS_NEURONS + numContextNeurons);
        System.out.println("Layerbreaks final are:" + vectorToText(this.layerBreaks));
        this.numNeurons = totalNeurons;
        this.learningRate = learningRate;
        this.momentumTerm = momentumTerm;

        setUpConnectivity();   // STRUCTURE WILL BE:  BIAS UNITS, CONTEXT UNITS, INPUT UNITS, ...
        System.out.println("NUMNEURONS = " + this.numNeurons);
        System.out.println("totalDeltaWij initialized in case of batchUpdate...");
        totalDeltaWij = new double[this.numNeurons][this.numNeurons];
    }

    /**
     * Switches batch mode on or off.
     *
     * <p>In batch mode, weight changes are accumulated by {@link #backpropogate} and only
     * applied when {@link #batchUpdateWeights} is called.
     */
    public void setBatchUpdate(boolean newValue)
    {
        System.out.print("enableBatchUpdate changed from " + enableBatchUpdate + " to ");
        enableBatchUpdate = newValue;
        System.out.println(enableBatchUpdate);
    }

    /**
     * Connects every non-bias neuron of each layer to every neuron of the next layer.
     * Bias neurons are connected to all neurons of every layer except layer 0. Each new
     * connection gets a small random weight.
     */
    public void setUpConnectivity()
    {
        for (int sourceLayer = 0; sourceLayer < layerBreaks.length; sourceLayer++)
        {
            int destLayer = sourceLayer + 1;

            for (int pre = 0 + NUM_BIAS_NEURONS; pre < numNeurons; pre++)
            {
                for (int post = 0 + NUM_BIAS_NEURONS; post < numNeurons; post++)
                {
                    if (inLayer(sourceLayer, pre) && inLayer(destLayer, post))
                    {
                        connected[pre][post] = true;
                        weight[pre][post] = smallRandomWeight();
                    }
                }
            }
        }

        // -------------------------------------------------------
        // CONNECT THE BIAS NODE TO ALL HIDDEN AND OUTPUT NEURONS
        // -------------------------------------------------------
        for (int pre = 0; pre < NUM_BIAS_NEURONS; pre++)
        {
            for (int destLayer = 1; destLayer < layerBreaks.length; destLayer++)
            {
                for (int post = 0; post < numNeurons; post++)
                {
                    if (inLayer(destLayer, post))
                    {
                        connected[pre][post] = true;
                        weight[pre][post] = smallRandomWeight();
                    }
                }
            }
        }
    }

    /** Indicates whether a neuron (absolute index) lies in the given layer. */
    private boolean inLayer(int layerNumber, int neuronNumber)
    {
        if (layerNumber >= layerBreaks.length) return false;
        if (layerNumber == 0)
        {
            return (neuronNumber >= 0) && (neuronNumber < layerBreaks[0]);
        }
        return (neuronNumber >= layerBreaks[layerNumber - 1]) && (neuronNumber < layerBreaks[layerNumber]);
    }

    /** Indicates whether a neuron (absolute index) is one of the context neurons. */
    private boolean isContextNeuron(int neuronNumber)
    {
        return (neuronNumber >= NUM_BIAS_NEURONS) && (neuronNumber < NUM_BIAS_NEURONS + numContextNeurons);
    }

    /**
     * Propagates an input pattern through the network.
     *
     * <p>Input values are copied into the input neurons in order. Every later neuron gets the
     * sigmoid of its weighted net input. Context neurons keep whatever activity they already
     * hold.
     *
     * @param inputPattern values for the input neurons. It may be shorter than the input
     *     layer; any remaining input neurons are computed like other neurons.
     * @throws IllegalArgumentException if inputPattern is longer than the input layer
     */
    public void feedforward(double[] inputPattern)
    {
        int numInputNeurons = layerBreaks[0] - NUM_BIAS_NEURONS - numContextNeurons;
        if (inputPattern.length > numInputNeurons)
        {
            // Otherwise the surplus values would overwrite hidden-layer neurons.
            throw new IllegalArgumentException("inputPattern has " + inputPattern.length
                    + " values but the input layer has only " + numInputNeurons + " neurons");
        }

        // ---------------------------------------------------
        // ACTIVATE THE BIAS NEURONS
        // ---------------------------------------------------
        for (int i = 0; i < NUM_BIAS_NEURONS; i++)
        {
            netInput[i] = BIAS_ACTIVATION;
            activity[i] = BIAS_ACTIVATION;
        }

        // ---------------------------------------------------
        // ACTIVATE ALL OTHER NON-BIAS, NON-CONTEXT NEURONS
        // ---------------------------------------------------
        int inputPatternIndex = 0;
        for (int post = 0 + NUM_BIAS_NEURONS + numContextNeurons; post < numNeurons; post++)
        {
            activity[post] = 0;
            netInput[post] = 0;

            if (inputPatternIndex < inputPattern.length)
            {
                // Input neuron: activity is the raw input value.
                netInput[post] = inputPattern[inputPatternIndex];
                activity[post] = inputPattern[inputPatternIndex];
                inputPatternIndex++;
            }
            else
            {
                for (int pre = 0; pre < numNeurons; pre++)
                {
                    if (connected[pre][post])
                    {
                        netInput[post] = netInput[post] + (weight[pre][post] * activity[pre]);
                    }
                }
                // This unit must apply sigmoid to itself before spreading activity to others.
                activity[post] = sigmoid(netInput[post]);
            }
        }
    }

    /**
     * Computes backpropagation weight changes for the current activities.
     *
     * <p>The returned changes already include the learning rate but not momentum. The
     * targets are aligned with the last {@code targetPattern.length} neurons.
     *
     * @param targetPattern desired activities of the output neurons
     * @return an {@code [numNeurons][numNeurons]} array of weight changes; entries for
     *     unconnected pairs are 0
     * @throws IllegalArgumentException if targetPattern is longer than the output layer
     */
    public double[][] getDerivativeAgainstWeights(double[] targetPattern)
    {
        requireTargetFitsOutputLayer(targetPattern);
        double[][] deltaWij = new double[numNeurons][numNeurons];
        double[] error = new double[numNeurons];
        int i = 0, j = 0;

        // ---------------------------------------------------
        // FIND ERROR FOR ALL OUTPUT UNITS BASED ON TARGET
        // ---------------------------------------------------
        j = targetPattern.length - 1;
        for (i = numNeurons - 1; i >= numNeurons - targetPattern.length; i--)
        {
            error[i] = (targetPattern[j] - activity[i]) * sigmoidDerivative(activity[i]);
            j--;
        }
        int lastNonOutputUnit = i;

        // ---------------------------------------------------
        // FIND WEIGHTED ERROR FOR ALL NON-OUTPUT UNITS
        // Connections only go to higher indices, so scanning downwards guarantees every
        // error[j] for j > i is already known.
        // ---------------------------------------------------
        for (i = lastNonOutputUnit; i >= 0; i--)
        {
            error[i] = 0;
            for (j = i + 1; j < numNeurons; j++)
            {
                if (connected[i][j]) error[i] += (error[j] * weight[i][j]);
            }
            error[i] = error[i] * sigmoidDerivative(activity[i]);
        }

        // ---------------------------------------------------
        // USE THIS ERROR INFO TO COMPUTE DELTA WEIGHTS
        // ---------------------------------------------------
        for (int pre = 0; pre < numNeurons; pre++)
        {
            for (int post = 0; post < numNeurons; post++)
            {
                if (connected[pre][post])
                {
                    deltaWij[pre][post] = (learningRate * activity[pre] * error[post]);
                }
            }
        }
        return deltaWij;
    }

    /** Throws IllegalArgumentException when targetPattern has more values than the output layer. */
    private void requireTargetFitsOutputLayer(double[] targetPattern)
    {
        // Same output-layer definition as getOutputLayerActivity().
        int outputLayerSize = numNeurons - layerBreaks[layerBreaks.length - 2];
        if (targetPattern.length > outputLayerSize)
        {
            throw new IllegalArgumentException("targetPattern has " + targetPattern.length
                    + " values but the output layer has only " + outputLayerSize + " neurons");
        }
    }

    // Sum of squared output errors accumulated by backpropogate() since the last clearSSEMeasure().
    double sumSquaredError;

    /** Returns the squared error accumulated since the last {@link #clearSSEMeasure()}. */
    public double getSumSquaredError()
    {
        return sumSquaredError;
    }

    /**
     * Performs one training step against targetPattern.
     *
     * <p>The step does three things:
     * <ol>
     *   <li>Adds this pattern's squared error to the running sum.</li>
     *   <li>Applies the weight changes, or accumulates them when batch mode is on.</li>
     *   <li>Copies the first hidden layer into the context neurons.</li>
     * </ol>
     *
     * <p>The misspelled method name is kept for compatibility with existing callers.
     *
     * @throws IllegalArgumentException if targetPattern is longer than the output layer
     */
    public void backpropogate(double[] targetPattern)
    {
        // Validate before touching sumSquaredError so a bad call leaves no partial state.
        requireTargetFitsOutputLayer(targetPattern);

        int j = 0;
        for (int i = numNeurons - targetPattern.length; i < numNeurons; i++)
        {
            sumSquaredError += (targetPattern[j] - activity[i]) * (targetPattern[j] - activity[i]);
            j++;
        }

        // ---------------------------------------------------
        // UPDATE THE WEIGHTS
        // ---------------------------------------------------
        if (!enableBatchUpdate)
        {
            double[][] deltaWij = getDerivativeAgainstWeights(targetPattern);
            updateWeights(deltaWij);
        }
        else
        {
            sumInto(totalDeltaWij, getDerivativeAgainstWeights(targetPattern));
        }

        // ---------------------------------------------------
        // SRN -- copy excitation to context units
        // ---------------------------------------------------
        copyHiddenExcitationToContextUnits();
    }

    /** Adds b into a element-wise (a is modified in place). */
    private void sumInto(double[][] a, double[][] b)
    {
        for (int i = 0; i < a.length; i++)
        {
            for (int j = 0; j < a[0].length; j++)
            {
                a[i][j] = a[i][j] + b[i][j];
            }
        }
    }

    /** Returns the element-wise sum of a and b as a new array. */
    private double[][] sum(double[][] a, double[][] b)
    {
        double[][] result = new double[a.length][a[0].length];
        for (int i = 0; i < a.length; i++)
        {
            for (int j = 0; j < a[0].length; j++)
            {
                result[i][j] = a[i][j] + b[i][j];
            }
        }
        return result;
    }

    /**
     * Applies the weight changes accumulated in batch mode and resets the accumulator.
     *
     * <p>Call this at the end of a sequence. Momentum still applies, because
     * {@link #updateWeights} adds {@code momentumTerm * previousDeltaWij}.
     */
    public void batchUpdateWeights()
    {
        updateWeights(totalDeltaWij);
        totalDeltaWij = new double[numNeurons][numNeurons];
    }

    /**
     * Adds momentum to deltaWij (in place), applies it to the connected weights, and
     * remembers the result for the next momentum step.
     */
    private void updateWeights(double[][] deltaWij)
    {
        for (int pre = 0; pre < numNeurons; pre++)
        {
            for (int post = 0; post < numNeurons; post++)
            {
                if (connected[pre][post])
                {
                    deltaWij[pre][post] = deltaWij[pre][post] + momentumTerm * (previousDeltaWij[pre][post]);
                    weight[pre][post] = weight[pre][post] + deltaWij[pre][post];
                    previousDeltaWij[pre][post] = deltaWij[pre][post];
                }
            }
        }
    }

    /** Copies the activity of the first hidden layer into the context neurons, in order. */
    public void copyHiddenExcitationToContextUnits()
    {
        int hiddenLayerIndex = layerBreaks[0];
        for (int contextNeuron = 0 + NUM_BIAS_NEURONS; contextNeuron < NUM_BIAS_NEURONS + numContextNeurons; contextNeuron++)
        {
            activity[contextNeuron] = activity[hiddenLayerIndex];
            hiddenLayerIndex++;
        }
    }

    /** Resets the accumulated sum of squared errors to zero. */
    public void clearSSEMeasure()
    {
        sumSquaredError = 0;
    }

    /** Zeroes the activity and net input of every context neuron (e.g. between sequences). */
    public void clearContextUnits()
    {
        for (int contextNeuron = 0 + NUM_BIAS_NEURONS; contextNeuron < NUM_BIAS_NEURONS + numContextNeurons; contextNeuron++)
        {
            activity[contextNeuron] = 0d;
            netInput[contextNeuron] = 0d;
        }
    }

    // ------------------------------------------------------------------------------------
    // -----------------------      SUPPORT FUNCTIONS -------------------------------------
    // ------------------------------------------------------------------------------------

    /** Logistic sigmoid. */
    protected double sigmoid(double x)
    {
        return (1.0d / (1.0d + Math.exp(-x)));
    }

    /**
     * Derivative of the logistic sigmoid, expressed in terms of its OUTPUT.
     *
     * @param x a neuron's activation, i.e. sigmoid(netInput), not the net input itself
     */
    protected double sigmoidDerivative(double x)
    {
        return (x * (1.0d - x));
    }

    /** Draws a weight uniformly from [-factor/2, factor/2). */
    private double smallRandomWeight()
    {
        return (initialRandomWeightFactor) * (randomNumberGenerator.nextDouble() - 0.5d);
    }

    /**
     * Reports whether connection i -> j exists.
     *
     * <p>NOTE(review): unlike {@link #disconnect} and {@link #getActivity}, this offsets
     * indices only by NUM_BIAS_NEURONS, not by numContextNeurons. So isConnected(a, b) does
     * not query the same connection that disconnect(a, b) removes. Left unchanged because
     * callers may rely on the current indexing; confirm which is intended.
     */
    public boolean isConnected(int i, int j)
    {
        return connected[i + NUM_BIAS_NEURONS][j + NUM_BIAS_NEURONS];
    }

    /** Sets weight[i][j] directly; i and j are absolute neuron indices (no offset applied). */
    public void setWeight(int i, int j, double value)
    {
        this.weight[i][j] = value;
    }

    /** Removes connection pre -> post; indices exclude bias and context neurons. */
    public void disconnect(int pre, int post)
    {
        connected[pre + NUM_BIAS_NEURONS + numContextNeurons][post + NUM_BIAS_NEURONS + numContextNeurons] = false;
    }

    /** Returns the activity of neuron i, where i excludes bias and context neurons. */
    public double getActivity(int i)
    {
        return activity[i + NUM_BIAS_NEURONS + numContextNeurons];
    }

    /** Returns a copy of the activities of the output (last) layer. */
    public double[] getOutputLayerActivity()
    {
        int firstOutputNeuron = layerBreaks[layerBreaks.length - 2];
        double[] result = new double[numNeurons - firstOutputNeuron];
        int j = 0;
        for (int i = firstOutputNeuron; i < numNeurons; i++)
        {
            result[j] = activity[i];
            j++;
        }
        return result;
    }

    /** Returns a new array with value added to every element of a. */
    private int[] incrementAllBy(int[] a, int value)
    {
        int[] result = new int[a.length];
        for (int i = 0; i < a.length; i++)
        {
            result[i] = value + a[i];
        }
        return result;
    }

    public void setLearningRate(double newRate)
    {
        this.learningRate = newRate;
    }

    public void setMomentumTerm(double newMomentum)
    {
        this.momentumTerm = newMomentum;
    }

    public double getLearningRate()
    {
        return this.learningRate;
    }

    public double getMomentumTerm()
    {
        return this.momentumTerm;
    }

    // ---------------------------------------------------------------------------
    //  Extra non-essential functions
    // ---------------------------------------------------------------------------

    /** Copies every weight (absolute indices) into the given Perceptron. */
    public void copyWeightsTo(Perceptron p)
    {
        for (int i = 0; i < numNeurons; i++)
        {
            for (int j = 0; j < numNeurons; j++)
            {
                p.setWeight(i, j, this.weight[i][j]);
            }
        }
    }

    /**
     * Returns a textual sketch of the layout.
     *
     * <p>Each neuron is written as "B " (bias), "C " (context) or "N " (other). A newline is
     * inserted at each layer break.
     */
    public String getTextRepresentation()
    {
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < this.numNeurons; i++)
        {
            for (int j = 0; j < this.layerBreaks.length; j++)
            {
                if (this.layerBreaks[j] == i) result.append('\n');
            }
            if (i < NUM_BIAS_NEURONS) { result.append("B "); }
            else if (i < numContextNeurons + NUM_BIAS_NEURONS) { result.append("C "); }
            else result.append("N ");
        }
        return result.toString();
    }

    /** Formats a as space-separated values followed by a newline. */
    private String vectorToText(double[] a)
    {
        StringBuilder result = new StringBuilder();
        for (double value : a)
        {
            result.append(value).append(' ');
        }
        return result.append('\n').toString();
    }

    /** Formats a as space-separated values followed by a newline. */
    private String vectorToText(int[] a)
    {
        StringBuilder result = new StringBuilder();
        for (int value : a)
        {
            result.append(value).append(' ');
        }
        return result.append('\n').toString();
    }

    /**
     * Lists, for every neuron, its outgoing connections as {@code " post@(weight)"}, one
     * neuron per line.
     */
    public String getConnectivity()
    {
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < this.numNeurons; i++)
        {
            result.append(i).append("\t->");
            for (int j = 0; j < this.numNeurons; j++)
            {
                if (connected[i][j])
                {
                    result.append(' ').append(j).append("@(").append(weight[i][j]).append(')');
                }
            }
            result.append('\n');
        }
        return result.toString();
    }

    /**
     * Renders the network as a Graphviz digraph.
     *
     * <p>Nodes are grouped into one cluster per layer plus a cluster for the context
     * neurons. Connections whose |weight| exceeds threshold are drawn as edges:
     * <ul>
     *   <li>positive weights: red with an "inv" arrowhead;</li>
     *   <li>negative weights: blue with a "tee" arrowhead;</li>
     *   <li>line width: lineWidthScale * |weight|.</li>
     * </ul>
     * The hidden-to-context copy links are drawn dotted.
     */
    public String getGraphviz(double threshold, double lineWidthScale)
    {
        StringBuilder result = new StringBuilder("digraph srn {\nranksep=5;\n");
        int currentLayer = 0;

        // One cluster per layer (context neurons are drawn separately below).
        result.append("subgraph cluster_layer").append(currentLayer).append("{\n");
        for (int i = 0; i < this.numNeurons; i++)
        {
            if (!inLayer(currentLayer, i))
            {
                result.append(" }\n");
                currentLayer++;
                result.append("subgraph cluster_layer").append(currentLayer).append("{\n");
            }
            if (!isContextNeuron(i)) result.append(i).append(';');
        }
        result.append("}\n");

        result.append("subgraph cluster_contextLayer {\n");
        for (int i = 0; i < this.numNeurons; i++)
        {
            if (isContextNeuron(i)) result.append(i).append(';');
        }
        result.append(" }\n");

        // Weighted connections above the threshold.
        for (int i = 0; i < this.numNeurons; i++)
        {
            for (int j = 0; j < this.numNeurons; j++)
            {
                if (connected[i][j] && Math.abs(weight[i][j]) > threshold)
                {
                    boolean excitatory = weight[i][j] > 0;
                    String arrowhead = excitatory ? "inv" : "tee";
                    result.append(i).append(" -> ").append(j).append(' ');
                    result.append("[style=\"setlinewidth(").append(lineWidthScale * Math.abs(weight[i][j]))
                          .append(")\" arrowhead=\"").append(arrowhead).append('"');
                    result.append(excitatory ? " color=\"red\"" : " color=\"blue\"");
                    result.append("];\n");
                }
            }
        }

        // DRAW CONTEXT LAYER COPY LINKS
        int hiddenLayerIndex = layerBreaks[0];
        for (int contextNeuron = 0 + NUM_BIAS_NEURONS; contextNeuron < NUM_BIAS_NEURONS + numContextNeurons; contextNeuron++)
        {
            result.append(hiddenLayerIndex).append("->").append(contextNeuron)
                  .append("[constraint=false arrowhead=\"inv\" style=\"dotted\"];\n");
            hiddenLayerIndex++;
        }

        result.append("}\n");
        return result.toString();
    }

}
