/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.neuralnet;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import net.pakl.neuralnet.SimpleRecurrentNet;

/**
 * Interactive text console for training a {@link SimpleRecurrentNet} on prompt/response sentence
 * pairs, presenting one character per time step.
 *
 * <p>Characters are encoded as 28-element one-hot vectors: indices 0-25 are the letters a-z,
 * index 26 is a space (word separator) and index 27 is the '.' end-of-message marker. The network
 * is laid out as 28 input units, {@code hiddenLayerSize} hidden units and 28 output units
 * (see the layer breaks passed to the {@link SimpleRecurrentNet} constructor).
 *
 * <p>Both constructors call {@link #start()}, so the session thread is already running when the
 * constructor returns. Because {@code this} escapes to that thread before construction finishes,
 * this class should not be subclassed with additional state.
 */
public class SRNInterface extends Thread {
    /** Length of every character vector: 26 letters, the word separator and the end marker. */
    private static final int VECTOR_SIZE = 28;
    /** Number of letter codes; letter {@code 'a' + k} is encoded at vector index {@code k}. */
    private static final int ALPHABET_SIZE = 26;
    /** Vector index that encodes the space character. */
    private static final int WORD_SEPARATION_INDEX = 26;
    /** Vector index that encodes the '.' end-of-message marker. */
    private static final int END_MESSAGE_INDEX = 27;
    /** Prompt length (excluding the trailing '.') that {@link #testNetwork} makes room for. */
    private static final int TEST_EXCHANGE_LENGTH = 40;
    /** Number of trials between progress reports during training. */
    private static final int REPORT_INTERVAL = 1000;
    /** Sequence kind for prompts: end-of-message filler is appended after the text. */
    private static final int QUESTION = 100;
    /** Sequence kind for responses: end-of-message filler is prepended before the text. */
    private static final int ANSWER = 101;

    /** Prompt sentences, index-aligned with {@link #outputStrings}. */
    ArrayList<String> inputStrings = new ArrayList<String>();
    /** Response sentences, index-aligned with {@link #inputStrings}. */
    ArrayList<String> outputStrings = new ArrayList<String>();
    /**
     * When true, the training target for the prompt portion of each exchange is the prompt
     * character itself (the network learns to echo its input) instead of end-of-message filler.
     */
    public static boolean TRAIN_IDENTITY_DURING_QUESTION = true;
    double[] WORD_SEPARATION_VECTOR = oneHot(WORD_SEPARATION_INDEX);
    double[] END_MESSAGE_VECTOR = oneHot(END_MESSAGE_INDEX);
    /** Row {@code k} is the one-hot vector for the letter {@code 'a' + k}. */
    double[][] ALPHABETIC_VECTOR = alphabeticVectors();
    /** First output unit index (inclusive). */
    private int beginOutput;
    /** Last output unit index (exclusive). */
    private int endOutput;
    private int totalTrials = 0;
    // volatile: the setters below are called from other threads while run() executes.
    private volatile PrintStream output = System.out;
    private volatile InputStream input = System.in;
    /** Training progress log; "error.txt" in the working directory, truncated when run() starts. */
    BufferedWriter errorFile;
    private boolean clearContextUnitsBetweenRuns = true;
    /** How many times each pattern is presented per trial. */
    int batchesPerTrial = 1;
    /** When true the net uses batch updates and batchUpdateWeights() runs after each presentation. */
    boolean batchUpdate = true;
    private int hiddenLayerSize = 0;
    /** Human-readable record of every training run in this session. */
    ArrayList<String> historyLog = new ArrayList<String>();

    /**
     * Creates the console using the given streams and immediately starts the session thread.
     *
     * @param newOutput stream that prompts and results are printed to
     * @param newInput stream that user input is read from
     */
    public SRNInterface(PrintStream newOutput, InputStream newInput) throws Exception {
        this.output = newOutput;
        this.input = newInput;
        this.start();
    }

    /** Creates the console on System.out/System.in and immediately starts the session thread. */
    public SRNInterface() throws Exception {
        this.start();
    }

    /** Returns a new 28-element vector that is 1.0 at {@code index} and 0.0 everywhere else. */
    private static double[] oneHot(int index) {
        double[] vector = new double[VECTOR_SIZE];
        vector[index] = 1.0;
        return vector;
    }

    /** Builds the 26 letter vectors; row {@code k} encodes the letter {@code 'a' + k}. */
    private static double[][] alphabeticVectors() {
        double[][] vectors = new double[ALPHABET_SIZE][];
        for (int letter = 0; letter < ALPHABET_SIZE; ++letter) {
            vectors[letter] = oneHot(letter);
        }
        return vectors;
    }

    /**
     * Runs the interactive session: prints the banner, asks for the network parameters, then
     * executes commands until "q"/"exit" (which terminates the JVM via System.exit) or until the
     * input stream ends (which ends this thread quietly, leaving an embedding host running).
     */
    @Override
    public void run() {
        try {
            this.errorFile = new BufferedWriter(new FileWriter("error.txt"));
            try {
                BufferedReader in = new BufferedReader(new InputStreamReader(this.input));
                this.printBanner();
                SimpleRecurrentNet net = this.createNetwork(in);
                this.output.println("\nASCII-ready Simple Neural Network created and ready for training.\n");
                this.runCommandLoop(in, net);
            }
            finally {
                // Closed on every exit path, including exceptions, so the log is never left unflushed.
                this.errorFile.close();
            }
            // runCommandLoop only returns normally when the user asked to quit.
            System.exit(0);
        }
        catch (EOFException ignored) {
            // Input stream exhausted: end the session without System.exit.
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Prints the start-up banner. */
    private void printBanner() {
        this.output.println("------------------------------------------------------------------------------");
        this.output.println("| Simple Recurrent Neural Network ASCII Interface Version 1.00               |");
        this.output.println("| Copyright (C) 2004 Patryk Laurent (patryk@cnbc.cmu.edu)                    |");
        this.output.println("|                                                                            |");
        this.output.println("|       o--< o--< o                                                          |");
        this.output.println("|       v    |                                                               |");
        this.output.println("|       \\____/                                                               |");
        this.output.println("------------------------------------------------------------------------------");
    }

    /**
     * Prompts for seed, hidden layer size, learning rate and momentum, then builds the network and
     * records the output-unit range used to decode its answers.
     *
     * @throws NumberFormatException if a numeric answer cannot be parsed
     * @throws EOFException if the input ends before all parameters are read
     */
    private SimpleRecurrentNet createNetwork(BufferedReader in) throws Exception {
        this.output.print("Random number seed: ");
        long seed = Long.parseLong(readLine(in).trim());
        this.output.print("Please enter hidden layer size: ");
        this.hiddenLayerSize = Integer.parseInt(readLine(in).trim());
        this.output.print("Please enter learning rate: ");
        double learningRate = Double.parseDouble(readLine(in).trim());
        this.output.print("Please enter momentum: ");
        double momentumTerm = Double.parseDouble(readLine(in).trim());
        this.beginOutput = VECTOR_SIZE + this.hiddenLayerSize;
        this.endOutput = this.beginOutput + VECTOR_SIZE;
        int[] layerBreaks = new int[]{VECTOR_SIZE, this.beginOutput, this.endOutput};
        SimpleRecurrentNet net = new SimpleRecurrentNet(seed, this.endOutput, learningRate, momentumTerm, layerBreaks);
        if (this.batchUpdate) {
            net.setBatchUpdate(true);
        }
        return net;
    }

    /**
     * Reads and executes commands until "q" or "exit" is entered.
     *
     * <p>Invalid input to a single command (unparsable numbers, unsupported characters) is reported
     * and the loop continues instead of ending the session.
     *
     * @throws EOFException if the input stream ends
     */
    private void runCommandLoop(BufferedReader in, SimpleRecurrentNet net) throws Exception {
        SimpleRecurrentNet current = net;
        while (true) {
            this.output.print("> ");
            String command = readLine(in);
            if (command.equals("q") || command.equals("exit")) {
                return;
            }
            try {
                current = this.executeCommand(command, in, current);
            }
            catch (IllegalArgumentException e) {
                // Includes NumberFormatException from numeric prompts.
                this.output.println("Invalid input: " + e.getMessage());
            }
        }
    }

    /**
     * Executes one console command, reading any follow-up answers from {@code in}.
     *
     * @return the network to use from now on; differs from {@code net} only after a successful "load"
     * @throws IllegalArgumentException if a follow-up answer is invalid
     * @throws EOFException if the input stream ends mid-command
     */
    private SimpleRecurrentNet executeCommand(String command, BufferedReader in, SimpleRecurrentNet net) throws Exception {
        SimpleRecurrentNet result = net;
        if (command.equals("learningRate") || command.equals("learn")) {
            this.output.println("New learning rate:");
            net.setLearningRate(Double.parseDouble(readLine(in).trim()));
        }
        else if (command.equals("clearContext")) {
            this.output.println("Clear context units between runs? (currently=" + this.clearContextUnitsBetweenRuns + "):");
            String answer = readLine(in);
            this.clearContextUnitsBetweenRuns = answer.equalsIgnoreCase("true") || answer.equalsIgnoreCase("yes");
            this.output.println("clearContextUnitsBetweenRuns = " + this.clearContextUnitsBetweenRuns);
        }
        else if (command.equals("momentum")) {
            this.output.println("New momentum:");
            net.setMomentumTerm(Double.parseDouble(readLine(in).trim()));
        }
        else if (command.equals("batchesPerTrial")) {
            this.output.println("New batchesPerTrial:");
            int batches = Integer.parseInt(readLine(in).trim());
            if (batches < 1) {
                // Zero or negative values would make "train" silently do nothing.
                throw new IllegalArgumentException("batchesPerTrial must be at least 1, got " + batches);
            }
            this.batchesPerTrial = batches;
        }
        else if (command.equals("add")) {
            this.output.println("Please enter a prompt sentence, then enter, then a response sentence.");
            this.output.print("Prompt: ");
            String prompt = readLine(in);
            this.output.print("Response: ");
            String response = readLine(in);
            // Validate before storing so a bad pair cannot break every later "train".
            this.requireEncodable(prompt);
            this.requireEncodable(response);
            this.inputStrings.add(prompt);
            this.outputStrings.add(response);
        }
        else if (command.equals("list")) {
            for (int i = 0; i < this.inputStrings.size(); ++i) {
                this.output.println(i + " " + this.inputStrings.get(i) + " -> " + this.outputStrings.get(i));
            }
        }
        else if (command.equals("train")) {
            this.output.print("How many trials? ");
            int trials = Integer.parseInt(readLine(in).trim());
            this.output.println(System.currentTimeMillis() + " Training network...");
            this.trainNetwork(net, trials);
            this.output.println(System.currentTimeMillis() + " Training done.");
        }
        else if (command.equals("connectivity")) {
            this.output.println(net.getConnectivity());
        }
        else if (command.equals("graph")) {
            this.output.print("Threshold?");
            double threshold = Double.parseDouble(readLine(in).trim());
            this.output.print("Linescale? ");
            double linescale = Double.parseDouble(readLine(in).trim());
            this.output.println(net.getGraphviz(threshold, linescale));
        }
        else if (command.equals("load")) {
            this.output.print("Filename? ");
            String filename = readLine(in);
            try {
                // NOTE(review): hiddenLayerSize/beginOutput/endOutput are not updated here, so a file
                // saved with a different hidden layer size will be decoded incorrectly.
                result = loadNetwork(filename);
            }
            catch (Exception e) {
                this.output.println("Could not load: " + e.getClass() + " " + e.getMessage());
            }
        }
        else if (command.equals("save")) {
            this.output.print("Filename? ");
            String filename = readLine(in);
            try {
                saveNetwork(net, filename);
            }
            catch (Exception e) {
                this.output.println("Could not save: " + e.getClass() + " " + e.getMessage());
            }
        }
        else if (command.equals("test")) {
            this.output.println("Testing network.  Enter blank line to end test.");
            while (true) {
                this.output.print("net> ");
                String inputString = readLine(in);
                if (inputString.equals("")) {
                    break;
                }
                try {
                    this.testNetwork(net, inputString);
                }
                catch (IllegalArgumentException e) {
                    // Stay in test mode after an unsupported character.
                    this.output.println("Invalid input: " + e.getMessage());
                }
            }
        }
        else if (command.equals("help")) {
            this.output.println("Commands: add list train test help connectivity graph load save learningRate momentum batchesPerTrial clearContext set history q exit");
        }
        else if (command.equals("set") || command.equals("history")) {
            this.output.println("Clear context between runs? clearContextUnitsBetweenRuns = " + this.clearContextUnitsBetweenRuns);
            this.output.println("Current Learning rate: " + net.getLearningRate());
            this.output.println("Current Momentum Term: " + net.getMomentumTerm());
            this.output.println("Total number of trials trained: " + this.totalTrials);
            this.output.println("Stimulus batch exposure per trial: " + this.batchesPerTrial);
            this.output.println("History:");
            this.output.println(this.historyLog);
        }
        else if (command.length() > 0) {
            this.output.println("Unknown command: " + command + " (type help for a list of commands)");
        }
        return result;
    }

    /**
     * Reads one line, treating end of stream as an error.
     *
     * @throws EOFException if the stream has no more lines
     */
    private static String readLine(BufferedReader in) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new EOFException("Input stream ended");
        }
        return line;
    }

    /**
     * Deserializes a network previously written by {@link #saveNetwork}. The file is always closed.
     * Uses Java native deserialization, so only load files you trust.
     */
    private static SimpleRecurrentNet loadNetwork(String filename) throws IOException, ClassNotFoundException {
        FileInputStream file = new FileInputStream(filename);
        try {
            return (SimpleRecurrentNet) new ObjectInputStream(file).readObject();
        }
        finally {
            file.close();
        }
    }

    /** Serializes {@code net} to {@code filename}. The file is always closed. */
    private static void saveNetwork(SimpleRecurrentNet net, String filename) throws IOException {
        FileOutputStream file = new FileOutputStream(filename);
        try {
            ObjectOutputStream stream = new ObjectOutputStream(file);
            stream.writeObject(net);
            // ObjectOutputStream buffers internally; without this flush the file can be truncated.
            stream.flush();
        }
        finally {
            file.close();
        }
    }

    /**
     * Copies the first {@code n} elements of {@code a} over the corresponding elements of {@code b}.
     * The size of {@code b} does not change.
     *
     * @param a source list
     * @param b destination list whose first {@code n} elements are replaced
     * @param n number of leading elements to copy
     * @throws IndexOutOfBoundsException if either list has fewer than {@code n} elements
     */
    public <T> void copyFromTo(ArrayList<? extends T> a, ArrayList<? super T> b, int n) {
        for (int i = 0; i < n; ++i) {
            b.set(i, a.get(i));
        }
    }

    /**
     * Trains {@code net} on every stored prompt/response pair for {@code trials} trials.
     *
     * <p>Each exchange lasts (prompt length + response length + 1) steps. The input is the
     * prompt followed by '.' and filler. The target is filler (or, when
     * {@link #TRAIN_IDENTITY_DURING_QUESTION} is set, the prompt itself) followed by the response
     * and '.'. Every {@value #REPORT_INTERVAL} trials (except trial 0) the per-pattern sum-squared
     * error is written to {@link #errorFile} and printed. Every {@value #REPORT_INTERVAL} trials
     * (including trial 0) the network's answer to each prompt is printed.
     */
    private void trainNetwork(SimpleRecurrentNet net, int trials) {
        int patternCount = this.inputStrings.size();
        List<List<double[]>> inputSequences = new ArrayList<List<double[]>>(patternCount);
        List<List<double[]>> targetSequences = new ArrayList<List<double[]>>(patternCount);
        int[] stepCounts = new int[patternCount];
        this.historyLog.add("Net trained for " + trials + " trials on batches x " + this.batchesPerTrial + " " + this.inputStrings + "->" + this.outputStrings + ".\n  LearningRate = " + net.getLearningRate() + ", momentum = " + net.getMomentumTerm() + " hiddenLayer = " + this.hiddenLayerSize + "\n");
        for (int i = 0; i < patternCount; ++i) {
            String question = this.inputStrings.get(i);
            String answer = this.outputStrings.get(i);
            // One extra step for the '.' that terminates the exchange.
            stepCounts[i] = question.length() + answer.length() + 1;
            ArrayList<double[]> inputSequence = this.convertStringToVectors(question + ".", stepCounts[i], QUESTION);
            ArrayList<double[]> targetSequence = this.convertStringToVectors(answer + ".", stepCounts[i], ANSWER);
            if (TRAIN_IDENTITY_DURING_QUESTION) {
                this.copyFromTo(inputSequence, targetSequence, question.length());
            }
            inputSequences.add(inputSequence);
            targetSequences.add(targetSequence);
        }
        double[] previousSSE = new double[patternCount];
        for (int trial = 0; trial < trials; ++trial) {
            ++this.totalTrials;
            boolean reportThisTrial = trial % REPORT_INTERVAL == 0;
            for (int i = 0; i < patternCount; ++i) {
                List<double[]> inputSequence = inputSequences.get(i);
                List<double[]> targetSequence = targetSequences.get(i);
                if (this.clearContextUnitsBetweenRuns) {
                    net.clearContextUnits();
                }
                net.clearSSEMeasure();
                for (int batch = 0; batch < this.batchesPerTrial; ++batch) {
                    for (int step = 0; step < stepCounts[i]; ++step) {
                        net.feedforward(inputSequence.get(step));
                        net.backpropogate(targetSequence.get(step));
                        net.copyHiddenExcitationToContextUnits();
                    }
                    if (this.batchUpdate) {
                        net.batchUpdateWeights();
                    }
                    // Report once per pattern, after its first batch, on reporting trials after trial 0.
                    if (reportThisTrial && trial > 0 && batch == 0) {
                        double sse = net.getSumSquaredError();
                        this.logTrainingError("Training trial " + trial + " SumSq Error: " + sse + " diff = " + (sse - previousSSE[i]));
                        previousSSE[i] = sse;
                    }
                }
            }
            if (reportThisTrial) {
                for (int testIndex = 0; testIndex < patternCount; ++testIndex) {
                    this.output.print(this.inputStrings.get(testIndex) + "\t-> ");
                    this.testNetwork(net, this.inputStrings.get(testIndex));
                }
            }
        }
        this.output.println("Total times through training set: " + this.totalTrials);
    }

    /** Appends {@code message} to the error log (best effort) and echoes it to the output. */
    private void logTrainingError(String message) {
        try {
            this.errorFile.write(message + "\n");
            this.errorFile.flush();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        this.output.println(message);
    }

    /**
     * Feeds {@code inputString} followed by '.' and end-of-message filler into the network for
     * {@value #TEST_EXCHANGE_LENGTH} + 1 steps, printing the decoded output after each step.
     * Only that many steps are presented, so a prompt longer than {@value #TEST_EXCHANGE_LENGTH}
     * characters is cut off before its end marker.
     *
     * @throws IllegalArgumentException if {@code inputString} contains an unsupported character
     */
    private void testNetwork(SimpleRecurrentNet net, String inputString) {
        int steps = TEST_EXCHANGE_LENGTH + 1;
        ArrayList<double[]> inputVectors = this.convertStringToVectors(inputString + ".", steps, QUESTION);
        if (this.clearContextUnitsBetweenRuns) {
            net.clearContextUnits();
        }
        for (int step = 0; step < steps; ++step) {
            net.feedforward(inputVectors.get(step));
            this.output.print(this.convertNetOutputToAlpha(net, this.beginOutput, this.endOutput));
            net.copyHiddenExcitationToContextUnits();
        }
        this.output.println("");
    }

    /**
     * Decodes the most active unit in {@code [start, cease)` as a character. Ties go to the lowest
     * index. Offset 26 decodes to a space, 27 to '.', and other offsets to {@code 'a' + offset}.
     */
    private String convertNetOutputToAlpha(SimpleRecurrentNet net, int start, int cease) {
        int winner = start;
        double winningActivity = net.getActivity(start);
        for (int unit = start + 1; unit < cease; ++unit) {
            double activity = net.getActivity(unit);
            if (activity > winningActivity) {
                winner = unit;
                winningActivity = activity;
            }
        }
        int code = winner - start;
        if (code == WORD_SEPARATION_INDEX) {
            return " ";
        }
        if (code == END_MESSAGE_INDEX) {
            return ".";
        }
        return String.valueOf((char) ('a' + code));
    }

    /** Debug helper: prints each vector in {@code doubleVectors} on its own line. */
    private void showVectorsInList(List<double[]> doubleVectors) {
        for (int i = 0; i < doubleVectors.size(); ++i) {
            this.showVector(doubleVectors.get(i));
            this.output.println("");
        }
    }

    /** Debug helper: prints the elements of {@code a} separated by spaces, without a newline. */
    private void showVector(double[] a) {
        for (int i = 0; i < a.length; ++i) {
            this.output.print(a[i] + " ");
        }
    }

    /**
     * Encodes {@code s} as one vector per character, padded with end-of-message vectors so the
     * result has {@code exchangeLength} entries (or {@code s.length()} if that is larger).
     *
     * @param s text to encode; letters (either case), spaces and '.' are supported
     * @param exchangeLength desired sequence length
     * @param questionOrAnswer {@link #QUESTION} appends the filler, {@link #ANSWER} prepends it;
     *     any other value adds no filler
     * @return the encoded sequence; entries share the vector instances held by this object
     * @throws IllegalArgumentException if {@code s} contains an unsupported character
     */
    private ArrayList<double[]> convertStringToVectors(String s, int exchangeLength, int questionOrAnswer) {
        ArrayList<double[]> result = new ArrayList<double[]>();
        int fillerRequired = exchangeLength - s.length();
        if (questionOrAnswer == ANSWER) {
            this.appendFiller(result, fillerRequired);
        }
        for (int i = 0; i < s.length(); ++i) {
            result.add(this.encodeCharacter(s.charAt(i)));
        }
        if (questionOrAnswer == QUESTION) {
            this.appendFiller(result, fillerRequired);
        }
        return result;
    }

    /** Appends {@code count} end-of-message vectors to {@code sequence}; non-positive counts add nothing. */
    private void appendFiller(List<double[]> sequence, int count) {
        for (int i = 0; i < count; ++i) {
            sequence.add(this.END_MESSAGE_VECTOR);
        }
    }

    /**
     * Returns the vector for one character. ASCII uppercase letters are folded to lowercase.
     *
     * @throws IllegalArgumentException for anything other than letters, space or '.'
     */
    private double[] encodeCharacter(char c) {
        if (c == '.') {
            return this.END_MESSAGE_VECTOR;
        }
        if (c == ' ') {
            return this.WORD_SEPARATION_VECTOR;
        }
        char letter = c;
        if (letter >= 'A' && letter <= 'Z') {
            letter = (char) (letter - 'A' + 'a');
        }
        if (letter < 'a' || letter > 'z') {
            throw new IllegalArgumentException("Unsupported character '" + c + "' (only letters, spaces and '.' are allowed)");
        }
        return this.ALPHABETIC_VECTOR[letter - 'a'];
    }

    /**
     * Checks that every character of {@code text} can be encoded.
     *
     * @throws IllegalArgumentException naming the first unsupported character
     */
    private void requireEncodable(String text) {
        for (int i = 0; i < text.length(); ++i) {
            this.encodeCharacter(text.charAt(i));
        }
    }

    /** Starts a console session on System.out/System.in. */
    public static void main(String[] args) throws Exception {
        new SRNInterface(System.out, System.in);
    }

    /** Redirects all subsequent console output to {@code newOutput}. */
    public void setOutput(PrintStream newOutput) {
        this.output = newOutput;
    }

    /**
     * Replaces the input stream. run() reads this field only once, when it creates its reader.
     * The constructor has already started the thread, so calling this afterwards races with that
     * read; prefer the two-argument constructor.
     */
    public void setInput(InputStream newInput) {
        this.input = newInput;
    }
}

