/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.neuralnet;

import java.io.Serializable;
import java.util.Random;
import net.pakl.neuralnet.Perceptron;

/**
 * A simple recurrent (Elman-style) network trained with backpropagation and momentum.
 *
 * <p>Neuron layout, as absolute indices into {@code activity}, {@code netInput} and {@code weight}:
 * <ul>
 *   <li>{@code [0, numContextNeurons)}: context neurons. They receive a copy of the first hidden
 *       layer's activity at the end of every {@link #backpropogate(double[])} call.</li>
 *   <li>{@code [numContextNeurons, layerBreaks[0])}: input neurons. They are clamped by
 *       {@link #feedforward(double[])}.</li>
 *   <li>{@code [layerBreaks[k-1], layerBreaks[k])}: layer {@code k}.</li>
 * </ul>
 *
 * <p>Layer 0 therefore contains both the context and the input neurons. Each layer is connected
 * only to the next one.
 *
 * <p>The {@code layerBreaks} passed to the constructor are exclusive upper bounds in caller
 * numbering, i.e. without context neurons. They are shifted by {@code numContextNeurons}
 * internally. The number of context neurons equals the size of the first hidden layer,
 * {@code layerBreaks[1] - layerBreaks[0]}.
 *
 * <p>Bias neurons are not supported: {@link #NUM_BIAS_NEURONS} is 0 and no bias wiring exists.
 */
public class SimpleRecurrentNet implements Serializable {
    /** Source of randomness for the initial weights; seeded in {@link #init}. */
    Random randomNumberGenerator;
    /** Total neuron count, including the context neurons (set by {@link #init}). */
    int numNeurons;
    double learningRate;
    double momentumTerm;
    /** Current activity of each neuron: the clamped input value, or the sigmoid of its net input. */
    double[] activity;
    /** Current net input of each neuron: the clamped input value, or its weighted input sum. */
    double[] netInput;
    /** {@code weight[pre][post]} is the weight of the connection {@code pre -> post}. */
    protected double[][] weight;
    /** {@code connected[pre][post]} tells whether the connection {@code pre -> post} exists. */
    boolean[][] connected;
    /** Weight changes accumulated over patterns while batch update is enabled. */
    double[][] totalDeltaWij;
    boolean enableBatchUpdate = false;
    /** Initial weights are drawn uniformly from {@code [-f/2, f/2)}, where {@code f} is this factor. */
    double initialRandomWeightFactor = 0.01;
    /** Exclusive upper neuron index of each layer, already shifted past the context neurons. */
    int[] layerBreaks;
    /** Number of bias neurons. Bias support is disabled, so this is 0. */
    public static final int NUM_BIAS_NEURONS = 0;
    /** Not referenced by this class; kept for compatibility. */
    public static final double BIAS_ACTIVATION = 0.0;
    public static final long serialVersionUID = 2564253280646345832L;
    /** Number of context neurons, equal to the size of the first hidden layer. */
    int numContextNeurons;
    /** Weight change applied in the previous update; feeds the momentum term. */
    double[][] previousDeltaWij;
    /** Sum of squared output errors accumulated by {@link #backpropogate} since {@link #clearSSEMeasure}. */
    double sumSquaredError;

    /** Creates an uninitialised network. {@link #init} must be called before any other method. */
    public SimpleRecurrentNet() {
    }

    /**
     * Creates and initialises a network with a custom initial weight scale.
     *
     * @param newInitialRandomWeightFactor scale {@code f} of the initial weights, which are drawn from
     *     {@code [-f/2, f/2)}
     * @see #init(long, int, double, double, int[]) for the remaining parameters
     */
    public SimpleRecurrentNet(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks, double newInitialRandomWeightFactor) {
        this.initialRandomWeightFactor = newInitialRandomWeightFactor;
        this.init(randomSeed, numNeurons, learningRate, momentumTerm, layerBreaks);
    }

    /**
     * Creates and initialises a network with the default initial weight scale.
     *
     * @see #init(long, int, double, double, int[])
     */
    public SimpleRecurrentNet(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks) {
        this.init(randomSeed, numNeurons, learningRate, momentumTerm, layerBreaks);
    }

    /**
     * Allocates all state, adds one context neuron per first-hidden-layer neuron, and wires
     * consecutive layers with small random weights. Progress is logged to stdout.
     *
     * @param randomSeed seed for the weight initialisation; equal seeds give equal initial weights
     * @param numNeurons neuron count excluding context neurons
     * @param learningRate scale applied to every weight change
     * @param momentumTerm fraction of the previous weight change added to the next one
     * @param layerBreaks exclusive upper index of each layer, in numbering without context neurons
     * @throws IllegalArgumentException if {@code layerBreaks} has fewer than 2 entries, or if
     *     {@code layerBreaks[1] < layerBreaks[0]}
     */
    public void init(long randomSeed, int numNeurons, double learningRate, double momentumTerm, int[] layerBreaks) {
        if (layerBreaks.length < 2) {
            throw new IllegalArgumentException("layerBreaks needs at least 2 entries (input layer and first hidden layer), got " + layerBreaks.length);
        }
        if (layerBreaks[1] < layerBreaks[0]) {
            throw new IllegalArgumentException("layerBreaks must be non-decreasing, got " + layerBreaks[0] + " then " + layerBreaks[1]);
        }
        this.randomNumberGenerator = new Random(randomSeed);
        System.out.println("SimpleRecurrentNet 3.1 created with random seed " + randomSeed + ". First random number is: ");
        // This consumes one draw from the generator, and the initial weights depend on it.
        System.out.println(this.randomNumberGenerator.nextDouble());
        this.numContextNeurons = layerBreaks[1] - layerBreaks[0];
        System.out.println("numContextNeurons = " + this.numContextNeurons);
        int totalNeurons = numNeurons + this.numContextNeurons;
        this.netInput = new double[totalNeurons];
        this.activity = new double[totalNeurons];
        this.weight = new double[totalNeurons][totalNeurons];
        this.connected = new boolean[totalNeurons][totalNeurons];
        this.previousDeltaWij = new double[totalNeurons][totalNeurons];
        System.out.println("Layerbreaks initially are:" + this.vectorToText(layerBreaks));
        // Context neurons occupy the lowest indices, so every caller-supplied break moves up.
        this.layerBreaks = this.incrementAllBy(layerBreaks, this.numContextNeurons);
        System.out.println("Layerbreaks final are:" + this.vectorToText(this.layerBreaks));
        this.numNeurons = totalNeurons;
        this.learningRate = learningRate;
        this.momentumTerm = momentumTerm;
        this.setUpConnectivity();
        System.out.println("NUMNEURONS = " + this.numNeurons);
        System.out.println("totalDeltaWij initialized in case of batchUpdate...");
        this.totalDeltaWij = new double[this.numNeurons][this.numNeurons];
    }

    /**
     * Switches between online updates (weights change after every pattern) and batch updates
     * (changes are accumulated until {@link #batchUpdateWeights()}). The change is logged to stdout.
     */
    public void setBatchUpdate(boolean newValue) {
        System.out.print("enableBatchUpdate changed from " + this.enableBatchUpdate + " to ");
        this.enableBatchUpdate = newValue;
        System.out.println(this.enableBatchUpdate);
    }

    /**
     * Connects every neuron of each layer to every neuron of the following layer and gives each
     * connection a small random weight. Existing connections are not cleared first.
     */
    public void setUpConnectivity() {
        for (int sourceLayer = 0; sourceLayer < this.layerBreaks.length; ++sourceLayer) {
            int destLayer = sourceLayer + 1;
            for (int pre = 0; pre < this.numNeurons; ++pre) {
                if (!this.inLayer(sourceLayer, pre)) {
                    continue;
                }
                // Iteration order (layer, pre, post) fixes the order of RNG draws; keep it stable.
                for (int post = 0; post < this.numNeurons; ++post) {
                    if (this.inLayer(destLayer, post)) {
                        this.connected[pre][post] = true;
                        this.weight[pre][post] = this.smallRandomWeight();
                    }
                }
            }
        }
    }

    /** Returns whether absolute neuron {@code neuronNumber} lies in layer {@code layerNumber}. */
    private boolean inLayer(int layerNumber, int neuronNumber) {
        if (layerNumber >= this.layerBreaks.length) {
            return false;
        }
        int lowerBound = layerNumber == 0 ? 0 : this.layerBreaks[layerNumber - 1];
        return neuronNumber >= lowerBound && neuronNumber < this.layerBreaks[layerNumber];
    }

    /** Returns whether absolute neuron {@code neuronNumber} is a context neuron. */
    private boolean isContextNeuron(int neuronNumber) {
        return neuronNumber >= 0 && neuronNumber < this.numContextNeurons;
    }

    /**
     * Propagates one input pattern forward through the network.
     *
     * <p>The leading non-context neurons are clamped to {@code inputPattern}. Every later neuron
     * gets the sigmoid of its weighted input sum. Context neurons keep their current activity.
     *
     * @param inputPattern values for the input neurons; may be shorter than the input layer
     * @throws IllegalArgumentException if the pattern is longer than the input layer
     */
    public void feedforward(double[] inputPattern) {
        int inputLayerSize = this.layerBreaks[0] - this.numContextNeurons;
        if (inputPattern.length > inputLayerSize) {
            throw new IllegalArgumentException("input pattern has " + inputPattern.length + " values but the input layer has only " + inputLayerSize + " neurons");
        }
        int inputPatternIndex = 0;
        for (int post = this.numContextNeurons; post < this.numNeurons; ++post) {
            if (inputPatternIndex < inputPattern.length) {
                this.netInput[post] = inputPattern[inputPatternIndex];
                this.activity[post] = inputPattern[inputPatternIndex];
                ++inputPatternIndex;
                continue;
            }
            double weightedSum = 0.0;
            for (int pre = 0; pre < this.numNeurons; ++pre) {
                if (this.connected[pre][post]) {
                    weightedSum += this.weight[pre][post] * this.activity[pre];
                }
            }
            this.netInput[post] = weightedSum;
            this.activity[post] = this.sigmoid(weightedSum);
        }
    }

    /**
     * Computes backpropagation weight changes for the current activity and a target pattern.
     *
     * <p>Despite the name, the result already includes the learning rate:
     * {@code delta[pre][post] = learningRate * activity[pre] * error[post]}. It excludes momentum.
     * The target is aligned with the highest-numbered neurons. Output neurons not covered by a
     * shorter target receive zero error.
     *
     * @param targetPattern desired activities of the last {@code targetPattern.length} neurons
     * @return a fresh {@code numNeurons x numNeurons} matrix; entries for unconnected pairs are 0
     * @throws IllegalArgumentException if the target is longer than the output layer
     */
    public double[][] getDerivativeAgainstWeights(double[] targetPattern) {
        this.requireTargetFitsOutputLayer(targetPattern);
        double[][] deltaWij = new double[this.numNeurons][this.numNeurons];
        double[] error = new double[this.numNeurons];
        int firstTargetNeuron = this.numNeurons - targetPattern.length;
        for (int k = 0; k < targetPattern.length; ++k) {
            int i = firstTargetNeuron + k;
            error[i] = (targetPattern[k] - this.activity[i]) * this.sigmoidDerivative(this.activity[i]);
        }
        // Walk backwards so that every downstream error[j] (j > i) is final before error[i] reads it.
        for (int i = firstTargetNeuron - 1; i >= 0; --i) {
            double backpropagated = 0.0;
            for (int j = i + 1; j < this.numNeurons; ++j) {
                if (this.connected[i][j]) {
                    backpropagated += error[j] * this.weight[i][j];
                }
            }
            error[i] = backpropagated * this.sigmoidDerivative(this.activity[i]);
        }
        for (int pre = 0; pre < this.numNeurons; ++pre) {
            for (int post = 0; post < this.numNeurons; ++post) {
                if (this.connected[pre][post]) {
                    deltaWij[pre][post] = this.learningRate * this.activity[pre] * error[post];
                }
            }
        }
        return deltaWij;
    }

    /** Throws if {@code targetPattern} has more entries than the output layer has neurons. */
    private void requireTargetFitsOutputLayer(double[] targetPattern) {
        int outputLayerSize = this.numNeurons - this.layerBreaks[this.layerBreaks.length - 2];
        if (targetPattern.length > outputLayerSize) {
            throw new IllegalArgumentException("target pattern has " + targetPattern.length + " values but the output layer has only " + outputLayerSize + " neurons");
        }
    }

    /** Returns the squared error accumulated since the last {@link #clearSSEMeasure()}. */
    public double getSumSquaredError() {
        return this.sumSquaredError;
    }

    /**
     * Runs one training step for the current activity.
     *
     * <p>The step has three parts:
     * <ol>
     *   <li>add the pattern's squared error to the running SSE;</li>
     *   <li>apply the weight change now, or accumulate it when batch update is enabled;</li>
     *   <li>copy the first hidden layer into the context neurons.</li>
     * </ol>
     *
     * @param targetPattern desired activities of the last {@code targetPattern.length} neurons
     * @throws IllegalArgumentException if the target is longer than the output layer
     */
    public void backpropogate(double[] targetPattern) {
        this.requireTargetFitsOutputLayer(targetPattern);
        int firstTargetNeuron = this.numNeurons - targetPattern.length;
        for (int k = 0; k < targetPattern.length; ++k) {
            double difference = targetPattern[k] - this.activity[firstTargetNeuron + k];
            this.sumSquaredError += difference * difference;
        }
        double[][] deltaWij = this.getDerivativeAgainstWeights(targetPattern);
        if (this.enableBatchUpdate) {
            this.sumInto(this.totalDeltaWij, deltaWij);
        } else {
            this.updateWeights(deltaWij);
        }
        this.copyHiddenExcitationToContextUnits();
    }

    /** Adds {@code b} element-wise into {@code a}; both must have the same rectangular shape. */
    private void sumInto(double[][] a, double[][] b) {
        for (int i = 0; i < a.length; ++i) {
            for (int j = 0; j < a[0].length; ++j) {
                a[i][j] = a[i][j] + b[i][j];
            }
        }
    }

    /** Applies the accumulated batch weight changes (with momentum), then resets the accumulator. */
    public void batchUpdateWeights() {
        this.updateWeights(this.totalDeltaWij);
        this.totalDeltaWij = new double[this.numNeurons][this.numNeurons];
    }

    /**
     * Adds momentum to {@code deltaWij} in place, applies it to every connected weight, and
     * remembers it as the previous change.
     */
    private void updateWeights(double[][] deltaWij) {
        for (int pre = 0; pre < this.numNeurons; ++pre) {
            for (int post = 0; post < this.numNeurons; ++post) {
                if (!this.connected[pre][post]) {
                    continue;
                }
                deltaWij[pre][post] = deltaWij[pre][post] + this.momentumTerm * this.previousDeltaWij[pre][post];
                this.weight[pre][post] = this.weight[pre][post] + deltaWij[pre][post];
                this.previousDeltaWij[pre][post] = deltaWij[pre][post];
            }
        }
    }

    /** Copies the first hidden layer's activity into the context neurons. */
    public void copyHiddenExcitationToContextUnits() {
        System.arraycopy(this.activity, this.layerBreaks[0], this.activity, 0, this.numContextNeurons);
    }

    /** Resets the accumulated sum of squared errors to 0. */
    public void clearSSEMeasure() {
        this.sumSquaredError = 0.0;
    }

    /** Zeroes the activity and net input of every context neuron, e.g. between sequences. */
    public void clearContextUnits() {
        for (int contextNeuron = 0; contextNeuron < this.numContextNeurons; ++contextNeuron) {
            this.activity[contextNeuron] = 0.0;
            this.netInput[contextNeuron] = 0.0;
        }
    }

    /** Logistic activation function {@code 1 / (1 + e^-x)}. */
    protected double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Derivative of the logistic function, expressed in terms of its OUTPUT.
     *
     * @param x a sigmoid output (an activity value), not a net input
     */
    protected double sigmoidDerivative(double x) {
        return x * (1.0 - x);
    }

    /** Draws a weight uniformly from {@code [-f/2, f/2)}, where {@code f = initialRandomWeightFactor}. */
    private double smallRandomWeight() {
        return this.initialRandomWeightFactor * (this.randomNumberGenerator.nextDouble() - 0.5);
    }

    /**
     * Returns whether the connection {@code i -> j} exists.
     *
     * <p>NOTE(review): this uses absolute indices (context neurons included), unlike
     * {@link #disconnect} and {@link #getActivity}, which skip the context neurons. Confirm that
     * callers rely on this asymmetry before changing it.
     */
    public boolean isConnected(int i, int j) {
        return this.connected[i][j];
    }

    /** Sets the weight of {@code i -> j}, using absolute indices (context neurons included). */
    public void setWeight(int i, int j, double value) {
        this.weight[i][j] = value;
    }

    /**
     * Removes the connection {@code pre -> post}. The indices are relative to the first
     * non-context neuron.
     */
    public void disconnect(int pre, int post) {
        this.connected[pre + this.numContextNeurons][post + this.numContextNeurons] = false;
    }

    /** Returns the activity of neuron {@code i}, indexed relative to the first non-context neuron. */
    public double getActivity(int i) {
        return this.activity[i + this.numContextNeurons];
    }

    /** Returns a copy of the activities of the last layer's neurons. */
    public double[] getOutputLayerActivity() {
        int firstOutputNeuron = this.layerBreaks[this.layerBreaks.length - 2];
        double[] result = new double[this.numNeurons - firstOutputNeuron];
        System.arraycopy(this.activity, firstOutputNeuron, result, 0, result.length);
        return result;
    }

    /** Returns a new array holding {@code a[i] + value} for every element. */
    private int[] incrementAllBy(int[] a, int value) {
        int[] result = new int[a.length];
        for (int i = 0; i < a.length; ++i) {
            result[i] = value + a[i];
        }
        return result;
    }

    public void setLearningRate(double newRate) {
        this.learningRate = newRate;
    }

    public void setMomentumTerm(double newMomentum) {
        this.momentumTerm = newMomentum;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public double getMomentumTerm() {
        return this.momentumTerm;
    }

    /** Copies the full weight matrix into {@code p}, using absolute indices (context neurons included). */
    public void copyWeightsTo(Perceptron p) {
        for (int i = 0; i < this.numNeurons; ++i) {
            for (int j = 0; j < this.numNeurons; ++j) {
                p.setWeight(i, j, this.weight[i][j]);
            }
        }
    }

    /**
     * Renders the layout as text: "C " for a context neuron, "N " for any other neuron, with a
     * newline before each neuron index that equals a layer break.
     */
    public String getTextRepresentation() {
        StringBuilder text = new StringBuilder();
        for (int i = 0; i < this.numNeurons; ++i) {
            for (int layerBreak : this.layerBreaks) {
                if (layerBreak == i) {
                    text.append('\n');
                }
            }
            text.append(i < this.numContextNeurons ? "C " : "N ");
        }
        return text.toString();
    }

    /** Formats {@code a} as space-separated values followed by a newline. */
    private String vectorToText(int[] a) {
        StringBuilder text = new StringBuilder();
        for (int value : a) {
            text.append(value).append(' ');
        }
        return text.append('\n').toString();
    }

    /**
     * Lists the outgoing connections of every neuron, one line per neuron:
     * {@code "i\t-> j@(weight) ..."}.
     */
    public String getConnectivity() {
        StringBuilder text = new StringBuilder();
        for (int i = 0; i < this.numNeurons; ++i) {
            text.append(i).append("\t->");
            for (int j = 0; j < this.numNeurons; ++j) {
                if (this.connected[i][j]) {
                    text.append(' ').append(j).append("@(").append(this.weight[i][j]).append(')');
                }
            }
            text.append('\n');
        }
        return text.toString();
    }

    /**
     * Renders the network as a Graphviz digraph.
     *
     * <p>Each layer, and the context neurons, get their own cluster. Connections whose absolute
     * weight exceeds {@code threshold} are drawn red with an "inv" arrowhead when positive, and
     * blue with a "tee" arrowhead otherwise. Line width is {@code lineWidthScale * |weight|}.
     * Dotted edges show the hidden-to-context copy.
     */
    public String getGraphviz(double threshold, double lineWidthScale) {
        StringBuilder dot = new StringBuilder("digraph srn {\nranksep=5;\n");
        int currentLayer = 0;
        dot.append("subgraph cluster_layer").append(currentLayer).append("{\n");
        for (int i = 0; i < this.numNeurons; ++i) {
            if (!this.inLayer(currentLayer, i)) {
                ++currentLayer;
                dot.append(" }\n").append("subgraph cluster_layer").append(currentLayer).append("{\n");
            }
            if (!this.isContextNeuron(i)) {
                dot.append(i).append(';');
            }
        }
        dot.append("}\n");
        dot.append("subgraph cluster_contextLayer {\n");
        for (int i = 0; i < this.numContextNeurons; ++i) {
            dot.append(i).append(';');
        }
        dot.append(" }\n");
        for (int i = 0; i < this.numNeurons; ++i) {
            for (int j = 0; j < this.numNeurons; ++j) {
                double w = this.weight[i][j];
                if (!this.connected[i][j] || !(Math.abs(w) > threshold)) {
                    continue;
                }
                boolean excitatory = w > 0.0;
                dot.append(i).append(" -> ").append(j).append(' ');
                dot.append("[style=\"setlinewidth(").append(lineWidthScale * Math.abs(w))
                        .append(")\" arrowhead=\"").append(excitatory ? "inv" : "tee").append('"');
                dot.append(excitatory ? " color=\"red\"" : " color=\"blue\"");
                dot.append("];\n");
            }
        }
        int hiddenNeuron = this.layerBreaks[0];
        for (int contextNeuron = 0; contextNeuron < this.numContextNeurons; ++contextNeuron) {
            dot.append(hiddenNeuron).append("->").append(contextNeuron)
                    .append("[constraint=false arrowhead=\"inv\" style=\"dotted\"];\n");
            ++hiddenNeuron;
        }
        dot.append("}\n");
        return dot.toString();
    }
}

