package net.pakl.neuralnet;
import java.io.Serializable;

/**
 * A layered feed-forward neural network ("multi-layer perceptron") trained by
 * backpropagation with an optional momentum term.
 *
 * <p>All neurons share one index space. Indices {@code 0 .. NUM_BIAS_NEURONS-1} are bias
 * neurons whose activity is always {@link #BIAS_ACTIVATION}; the remaining neurons are split
 * into consecutive layers by {@code layerBreaks}. Every neuron of a layer is connected to every
 * neuron of the next layer, and every bias neuron is connected to every non-input neuron.
 * Activity only flows from lower to higher indices.
 *
 * <p>Index conventions differ between accessors: {@link #isConnected}, {@link #getActivity}
 * and {@link #getNetInput} take indices that EXCLUDE the bias neurons, whereas
 * {@link #setWeight}, {@link #getWeights}, {@link #getActivityVector} and
 * {@link #getNumNeurons} use the internal index space that INCLUDES them.
 *
 * <p>No synchronization is performed; feedforward and training mutate per-instance arrays.
 */
public class Perceptron implements Serializable
{
    /** Source of the random initial weights; seeded in constructPerceptron. */
    java.util.Random randomNumberGenerator;
    /** Total number of neurons, INCLUDING the bias neurons. */
    int numNeurons;
    /** Step size applied to every backpropagation weight change. */
    public double learningRate;
    /** Fraction of the previous weight change that is added to the current one. */
    double momentumTerm;
    /** Output of each neuron after the most recent feedforward call. */
    double[] activity;
    /** Weighted input (or clamped input value) of each neuron after the most recent feedforward call. */
    double[] netInput;

    /** weight[pre][post]: weight of the connection from neuron pre to neuron post. */
    protected double[][] weight;
    /** connected[pre][post]: whether neuron pre feeds neuron post. */
    boolean[][] connected;
    /**
     * Exclusive end index of each layer in the internal index space. Callers pass breaks that
     * exclude the bias neurons (a 3-2-3 network has breaks at 3, 5, 8); they are stored shifted
     * by NUM_BIAS_NEURONS, so layer 0 also contains the bias neurons.
     */
    int [] layerBreaks;

    /**
     * Number of bias neurons (default 1) so that neurons effectively have thresholds.
     * This field is static: changing it affects every Perceptron, including existing instances
     * whose methods read it.
     */
    public static int NUM_BIAS_NEURONS   = 1;
    /** Fixed activity of every bias neuron; only the weights leaving a bias neuron are learned. */
    public static final double BIAS_ACTIVATION = -1;
    /** Weight changes applied by the previous backpropogate call (used for momentum). */
    double [][] previousDeltaWij;

    /** Multiplier applied to every initial random weight (see smallRandomWeight). */
    double randomInitialScale = 1.0;

    /** Sum of the signed output-neuron error terms computed by the last getDeltaWij call. */
    private double totalError;

    /** Sum of squared output errors measured at the start of the last backpropogate call. */
    double sumSquaredError;

    /** Returns the total number of neurons, including the bias neurons. */
    public int getNumNeurons()
    {
        return numNeurons;
    }

    /**
     * Creates an uninitialized network; {@link #constructPerceptron} must be called before the
     * network is used (presumably by a subclass — confirm against callers).
     */
    public Perceptron()
    {
    }

    /**
     * Creates a network with randomly initialized weights.
     *
     * @param randomSeed   seed for the initial random weights
     * @param numNeurons   number of neurons, excluding the bias neurons
     * @param learningRate step size used by backpropagation
     * @param momentumTerm fraction of the previous weight change re-applied on each update
     * @param layerBreaks  exclusive end index of each layer, excluding the bias neurons
     *                     (a 3-2-3 network has breaks {3, 5, 8})
     */
    public Perceptron(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks)
    {
        constructPerceptron(randomSeed, numNeurons, learningRate, momentumTerm, layerBreaks);
    }

    /**
     * Creates a network after changing the number of bias neurons.
     *
     * <p>WARNING: {@code NUM_BIAS_NEURONS} is static, so this changes the bias count for every
     * Perceptron in the JVM, not only for the instance being constructed.
     *
     * @param new_NUM_BIAS_NEURONS new (global) value for {@link #NUM_BIAS_NEURONS}
     * @see #Perceptron(long, int, double, double, int[])
     */
    public Perceptron(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks, int new_NUM_BIAS_NEURONS)
    {
        System.out.println("Note: Number of BIAS Neurons manually changed from default of " + NUM_BIAS_NEURONS + " to " + new_NUM_BIAS_NEURONS);
        NUM_BIAS_NEURONS = new_NUM_BIAS_NEURONS;   // static write: global effect
        constructPerceptron(randomSeed, numNeurons, learningRate, momentumTerm, layerBreaks);
    }

    /**
     * Allocates all per-neuron state for {@code numNeurons + NUM_BIAS_NEURONS} neurons, stores
     * the training parameters, shifts the layer breaks past the bias neurons and then wires up
     * the network with random weights via {@link #setUpConnectivity}.
     */
    protected void constructPerceptron(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks)
    {
        int total = numNeurons + NUM_BIAS_NEURONS;
        randomNumberGenerator = new java.util.Random(randomSeed);
        activity = new double[total];
        netInput = new double[total];
        weight = new double[total][total];
        connected = new boolean[total][total];
        previousDeltaWij = new double[total][total];
        this.layerBreaks = incrementAllBy(layerBreaks, NUM_BIAS_NEURONS);

        this.numNeurons = total;
        this.learningRate = learningRate;
        this.momentumTerm = momentumTerm;

        setUpConnectivity();
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    protected double sigmoid(double x)
    {
        return 1.0d / (1.0d + Math.exp(-x));
    }

    /** Derivative of {@link #sigmoid} at {@code x}, i.e. {@code s(x) * (1 - s(x))}. */
    protected double sigmoidDerivative(double x)
    {
        double s = sigmoid(x);
        return s * (1.0d - s);
    }

    /**
     * Sets the scale of the initial random weights and re-draws every connected weight.
     * The random generator is not reseeded (the new weights continue its sequence), and the
     * stored momentum in previousDeltaWij is left unchanged.
     */
    public void reinitializeRandomWeightsWith(double scale)
    {
        randomInitialScale = scale;
        System.out.println("Perceptron: randomInitialScale = "+randomInitialScale);
        this.setUpConnectivity();
    }

    /** Returns a value uniformly drawn from [-0.1, 0.1), multiplied by randomInitialScale. */
    private double smallRandomWeight()
    {
        return (randomInitialScale) * ((randomNumberGenerator.nextDouble() - 0.5d)/5.d);
    }

    /**
     * Connects every non-bias neuron of each layer to every neuron of the next layer, and each
     * bias neuron to every neuron of layers 1 and above. Each of these connections receives a
     * fresh weight from smallRandomWeight, i.e. uniform in [-0.1, 0.1) scaled by
     * randomInitialScale. Connection flags are only ever set here, never cleared.
     *
     * <p>The loop order fixes the order in which random numbers are drawn, and therefore the
     * weights produced for a given seed.
     */
    public void setUpConnectivity()
    {
        // Layer-to-layer connections (bias neurons are skipped as senders here).
        for (int sourceLayer = 0; sourceLayer + 1 < layerBreaks.length; sourceLayer++)
        {
            int destLayer = sourceLayer + 1;

            for (int pre = NUM_BIAS_NEURONS; pre < numNeurons; pre++)
            {
                for (int post = NUM_BIAS_NEURONS; post < numNeurons; post++)
                {
                    if (inLayer(sourceLayer, pre) && inLayer(destLayer, post))
                    {
                        connected[pre][post] = true;
                        weight[pre][post] = smallRandomWeight();
                    }
                }
            }
        }

        // Connect each bias neuron to all hidden and output neurons.
        for (int pre = 0; pre < NUM_BIAS_NEURONS; pre++)
        {
            for (int destLayer = 1; destLayer < layerBreaks.length; destLayer++)
            {
                for (int post = NUM_BIAS_NEURONS; post < numNeurons; post++)
                {
                    if (inLayer(destLayer, post))
                    {
                        connected[pre][post] = true;
                        weight[pre][post] = smallRandomWeight();
                    }
                }
            }
        }
    }

    /**
     * Copies every weight of this network into {@code p}, index for index (bias neurons
     * included). Connectivity flags are not copied; {@code p} must have at least as many neurons.
     */
    public void copyWeightsTo(Perceptron p)
    {
        for (int i = 0; i < numNeurons; i++)
        {
            for (int j = 0; j < numNeurons; j++)
            {
                p.setWeight(i, j, this.weight[i][j]);
            }
        }
    }

    /**
     * Indicates whether internal neuron index {@code neuronNumber} lies in layer
     * {@code layerNumber}. Layer 0 starts at index 0 and therefore includes the bias neurons.
     * Returns false for layer numbers past the last layer.
     */
    protected boolean inLayer(int layerNumber, int neuronNumber)
    {
        if (layerNumber >= layerBreaks.length) return false;
        int lowerBound = (layerNumber == 0) ? 0 : layerBreaks[layerNumber - 1];
        return neuronNumber >= lowerBound && neuronNumber < layerBreaks[layerNumber];
    }

    /**
     * Propagates an input pattern through the network, updating activity and netInput.
     *
     * <p>Bias neurons are set to BIAS_ACTIVATION. The first {@code inputPattern.length} non-bias
     * neurons take the pattern values unchanged (no sigmoid). Every later neuron receives the
     * weighted sum of the activities of the lower-indexed neurons connected to it and outputs
     * its sigmoid. NOTE(review): a pattern longer than the input layer is clamped onto the
     * following neurons as well.
     */
    public void feedforward(double[] inputPattern)
    {
        for (int i = 0; i < NUM_BIAS_NEURONS; i++)
        {
            netInput[i] = BIAS_ACTIVATION;
            activity[i] = BIAS_ACTIVATION;
        }

        int inputPatternIndex = 0;
        for (int post = NUM_BIAS_NEURONS; post < numNeurons; post++)
        {
            if (inputPatternIndex < inputPattern.length)
            {
                netInput[post] = inputPattern[inputPatternIndex];
                activity[post] = inputPattern[inputPatternIndex];
                inputPatternIndex++;
            }
            else
            {
                // Every sender has a lower index and is already up to date, so one pass suffices.
                double sum = 0;
                for (int pre = 0; pre < post; pre++)
                {
                    if (connected[pre][post])
                    {
                        sum = sum + (weight[pre][post] * activity[pre]);
                    }
                }
                netInput[post] = sum;
                activity[post] = sigmoid(sum);
            }
        }
    }

    /**
     * Returns, for the activity produced by the last {@link #feedforward} call, the partial
     * derivative of the network output with respect to every connected weight.
     *
     * <p>For a weight feeding an output-layer neuron {@code o}, the entry is
     * d(activity[o]) / d(weight[pre][o]) = a_o (1 - a_o) a_pre. For any other connected weight
     * (a weight feeding a hidden neuron {@code h}), the entry is the derivative of the LAST
     * neuron ({@code getNumNeurons()-1}) through {@code h}:
     * a_out (1 - a_out) w[h][out] a_h (1 - a_h) a_pre. This assumes a three-layer
     * (input, hidden, output) network.
     *
     * <p>The formula is chosen by the receiving neuron's layer. Bias neurons belong to layer 0
     * yet feed the output layer directly, so classifying by the sending neuron would
     * mis-handle bias-to-output weights.
     *
     * @return a {@code getNumNeurons() x getNumNeurons()} matrix indexed [pre][post];
     *         unconnected entries are 0
     */
    public double [][] getOutputDerivativeAgainstWeights()
    {
        double[][] derivative = new double[numNeurons][numNeurons];
        int outputLayer = layerBreaks.length - 1;
        int output = this.getNumNeurons() - 1;
        for (int post = 0; post < numNeurons; post++)
        {
            double postSlope = activity[post] * (1.0 - activity[post]);
            for (int pre = 0; pre < numNeurons; pre++)
            {
                if (!connected[pre][post])
                {
                    continue;
                }
                if (inLayer(outputLayer, post))
                {
                    derivative[pre][post] = postSlope * activity[pre];
                }
                else
                {
                    double outputSlope = activity[output] * (1.0 - activity[output]);
                    derivative[pre][post] = outputSlope * weight[post][output] * postSlope * activity[pre];
                }
            }
        }
        return derivative;
    }

    /**
     * Computes backpropagation weight changes for the activity from the last feedforward call,
     * without applying them.
     *
     * <p>The last {@code outputPattern.length} neurons are the output neurons, matched to the
     * pattern in order. An output neuron's error term is (target - a) * a * (1 - a). Every
     * lower-indexed neuron's error term is the weighted sum of the error terms of the neurons it
     * feeds, times a * (1 - a). As a side effect, totalError is set to the sum of the signed
     * output error terms.
     *
     * @param outputPattern target activities for the output neurons
     * @return deltaWij[pre][post] = learningRate * activity[pre] * error[post] for connected
     *         pairs, 0 elsewhere (momentum is not applied here)
     */
    public double [][] getDeltaWij(double[] outputPattern)
    {
        double[][] deltaWij = new double[numNeurons][numNeurons];
        totalError = 0.0d;
        double[] error = new double[numNeurons];

        // Error terms of the output neurons (the last outputPattern.length neurons).
        int inputIndex = outputPattern.length - 1;
        int outputNode;
        for (outputNode = numNeurons-1; outputNode >= numNeurons-outputPattern.length; outputNode--)
        {
            error[outputNode] = (outputPattern[inputIndex] - activity[outputNode])
                            * activity[outputNode] * (1.0d-activity[outputNode]);
            totalError += error[outputNode];
            inputIndex--;
        }
        int lastNonOutputUnit = outputNode;

        // Error terms of all other neurons. They are visited from high to low index, so every
        // neuron a unit feeds (always higher-indexed) already has its error term.
        for (int nonOutputNode = lastNonOutputUnit; nonOutputNode >= 0; nonOutputNode--)
        {
            error[nonOutputNode] = 0;
            for (int post = nonOutputNode + 1; post < numNeurons; post++)
            {
                if (connected[nonOutputNode][post])
                {
                    error[nonOutputNode] += (error[post] * weight[nonOutputNode][post]);
                }
            }
            error[nonOutputNode] *= (activity[nonOutputNode] * (1.0d-activity[nonOutputNode]));
        }

        for (int pre = 0; pre < numNeurons; pre++)
        {
            for (int post = 0; post < numNeurons; post++)
            {
                if (connected[pre][post])
                {
                    deltaWij[pre][post] = (learningRate * activity[pre] * error[post]);
                }
            }
        }
        return deltaWij;
    }

    /** Returns the sum of squared output errors recorded by the last backpropogate call. */
    public double getSumSquaredError()
    {
        return sumSquaredError;
    }

    /**
     * Performs one training step on the activity from the last feedforward call. It records the
     * sum of squared output errors, then adds to every connected weight its getDeltaWij change
     * plus momentumTerm times the previous change. The method name's spelling is kept for
     * existing callers.
     *
     * @param outputPattern target activities for the last {@code outputPattern.length} neurons
     */
    public void backpropogate(double[] outputPattern)
    {
        sumSquaredError = 0;
        int j = 0;
        for (int i = numNeurons-outputPattern.length; i < numNeurons; i++)
        {
            double diff = outputPattern[j] - activity[i];
            sumSquaredError += diff * diff;
            j++;
        }

        double [][] deltaWij = getDeltaWij(outputPattern);

        for (int pre = 0; pre < numNeurons; pre++)
        {
            for (int post = 0; post < numNeurons; post++)
            {
                if (connected[pre][post])
                {
                    deltaWij[pre][post] = deltaWij[pre][post] + momentumTerm * (previousDeltaWij[pre][post]);
                    weight[pre][post] = weight[pre][post] + deltaWij[pre][post];
                    previousDeltaWij[pre][post] = deltaWij[pre][post];
                }
            }
        }
    }

    /** Returns whether neuron i feeds neuron j; both indices EXCLUDE the bias neurons. */
    public boolean isConnected(int i, int j)
    {
        return connected[i + NUM_BIAS_NEURONS][j + NUM_BIAS_NEURONS];
    }

    /**
     * Describes the connectivity as text, one line per neuron (internal indices, bias included):
     * {@code "i\t-> j@(w) k@(w)...\n"}. Unlike earlier versions, no per-pair debug output is
     * printed to stdout.
     */
    public String getConnectivity()
    {
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < this.numNeurons; i++)
        {
            result.append(i).append("\t->");
            for (int j = 0; j < this.numNeurons; j++)
            {
                if (connected[i][j])
                {
                    result.append(' ').append(j).append("@(").append(weight[i][j]).append(')');
                }
            }
            result.append('\n');
        }
        return result.toString();
    }

    /** Returns the sum of the signed output error terms from the last getDeltaWij call (not a squared error). */
    public double getTotalError()
    {
        return totalError;
    }

    /** Sets weight[i][j] directly; indices INCLUDE the bias neurons. */
    public void setWeight(int i, int j, double value)
    {
        this.weight[i][j] = value;
    }

    /** Returns the internal weight matrix itself (not a copy); changes to it affect the network. */
    public double [][] getWeights()
    {
        return this.weight;
    }

    /** Copies the top-left numNeurons x numNeurons entries of newWeights into this network. */
    public void setWeights(double [][] newWeights)
    {
        for (int i = 0; i < numNeurons; i++)
        {
            for (int j = 0; j < numNeurons; j++)
            {
                this.weight[i][j] = newWeights[i][j];
            }
        }
    }

    /** Returns a copy of the weight matrix flattened row by row (index {@code i * numNeurons + j}). */
    public double [] getWeights1D()
    {
        double [] result = new double[numNeurons * numNeurons];
        for (int i = 0; i < numNeurons; i++)
        {
            System.arraycopy(weight[i], 0, result, i * numNeurons, numNeurons);
        }
        return result;
    }

    /** Loads weights from a row-major flattened vector, the inverse of getWeights1D. */
    public void setWeights1D(double [] newWeights1D)
    {
        int z = 0;
        for (int i = 0; i < numNeurons; i++)
        {
            for (int j = 0; j < numNeurons; j++)
            {
                this.weight[i][j] = newWeights1D[z];
                z++;
            }
        }
    }

    /** Returns the activity of neuron i, where i EXCLUDES the bias neurons. */
    public double getActivity(int i)
    {
        return activity[i + NUM_BIAS_NEURONS];
    }

    /** Returns a copy of all activities, indexed INCLUDING the bias neurons. */
    public double [] getActivityVector()
    {
        double [] result = new double[numNeurons];
        System.arraycopy(activity, 0, result, 0, numNeurons);
        return result;
    }

    /** Returns the net input of neuron i, where i EXCLUDES the bias neurons. */
    public double getNetInput(int i)
    {
        return netInput[i + NUM_BIAS_NEURONS];
    }

    /** Returns a new array holding each element of {@code a} plus {@code value}. */
    private int[] incrementAllBy(int [] a, int value)
    {
        int [] result = new int[a.length];
        for (int i = 0; i < a.length; i++)
        {
            result[i] = value + a[i];
        }
        return result;
    }

    /**
     * Renders the network in Graphviz DOT format. Each layer becomes a cluster subgraph (bias
     * neurons fall into layer 0). Each connection whose |weight| exceeds {@code threshold}
     * becomes an edge of width {@code lineWidthScale * |weight|`}, drawn red with an "inv"
     * arrowhead when the weight is positive and blue with a "tee" arrowhead otherwise.
     */
    public String getGraphviz(double threshold, double lineWidthScale)
    {
        StringBuilder dot = new StringBuilder("digraph srn {\nranksep=5;\n");

        // Neurons are visited in index order, so a new cluster starts whenever a neuron falls
        // outside the current layer.
        int currentLayer = 0;
        dot.append("subgraph cluster_layer").append(currentLayer).append("{\n");
        for (int i = 0; i < this.numNeurons; i++)
        {
            if (!inLayer(currentLayer, i))
            {
                dot.append(" }\n");
                currentLayer++;
                dot.append("subgraph cluster_layer").append(currentLayer).append("{\n");
            }
            dot.append(i).append(';');
        }
        dot.append("}\n");

        for (int i = 0; i < this.numNeurons; i++)
        {
            for (int j = 0; j < this.numNeurons; j++)
            {
                if (connected[i][j] && Math.abs(weight[i][j]) > threshold)
                {
                    boolean positive = weight[i][j] > 0;
                    dot.append(i).append(" -> ").append(j).append(' ')
                       .append("[style=\"setlinewidth(").append(lineWidthScale * Math.abs(weight[i][j]))
                       .append(")\" arrowhead=\"").append(positive ? "inv" : "tee").append('"')
                       .append(positive ? " color=\"red\"" : " color=\"blue\"")
                       .append("];\n");
                }
            }
        }

        dot.append("}\n");
        return dot.toString();
    }
}
