
package net.pakl.neuralnet;

import java.util.*;
import java.io.*;

/**
 * Interactive ASCII console for training and probing a {@link SimpleRecurrentNet}
 * on prompt/response sentence pairs.
 *
 * <p>Characters are presented to the network one per time step as 28-element
 * one-hot vectors: positions 0-25 encode the letters {@code a}-{@code z},
 * position 26 the word separator (space) and position 27 the end-of-message
 * marker ({@code '.'}).
 *
 * <p>Both constructors start the thread immediately; {@link #run()} then reads
 * commands from the configured input stream until the user quits or the input
 * is exhausted.  Error measurements taken during training are appended to
 * {@code error.txt} in the working directory.
 *
 * <p>NOTE(review): the fields of this class are read by the console thread without
 * any synchronization, so changing them from another thread while the console
 * is running is not guaranteed to be seen consistently.
 */
public class SRNInterface extends Thread
{
    /** Number of alphabetic symbols that can be encoded. */
    private static final int ALPHABET_SIZE = 26;
    /** Position of the "1" in the word separator vector. */
    private static final int WORD_SEPARATION_INDEX = ALPHABET_SIZE;
    /** Position of the "1" in the end-of-message vector. */
    private static final int END_MESSAGE_INDEX = ALPHABET_SIZE + 1;
    /** Width of every input and output vector (letters + separator + end marker). */
    private static final int VECTOR_SIZE = ALPHABET_SIZE + 2;
    /** Number of time steps (minus one) presented to the network by a test run. */
    private static final int MAX_TEST_EXCHANGE_LENGTH = 40;
    /** Progress is reported every this many training trials. */
    private static final int REPORT_INTERVAL = 1000;

    /** Padding mode: the message comes first, end-of-message filler follows. */
    private static final int QUESTION = 100;
    /** Padding mode: end-of-message filler comes first, the message follows. */
    private static final int ANSWER = 101;

    /** Prompt sentences entered with the "add" command (elements are Strings). */
    ArrayList inputStrings = new ArrayList();
    /** Response sentences, index-aligned with {@link #inputStrings}. */
    ArrayList outputStrings = new ArrayList();
    /** One human-readable entry per completed "train" command. */
    ArrayList historyLog = new ArrayList();

    /**
     * When true, the training target during the question phase is the question
     * itself (the net is trained to echo its input) instead of filler.
     */
    public static boolean TRAIN_IDENTITY_DURING_QUESTION = true;

    double [] WORD_SEPARATION_VECTOR = oneHot(WORD_SEPARATION_INDEX);
    double [] END_MESSAGE_VECTOR = oneHot(END_MESSAGE_INDEX);
    /** Row i is the one-hot vector for the letter {@code 'a' + i}. */
    double [] [] ALPHABETIC_VECTOR = alphabetVectors();

    /** Index range [beginOutput, endOutput) of the output layer's units in the net. */
    private int beginOutput, endOutput;
    private int totalTrials = 0;
    private int hiddenLayerSize = 0;
    private PrintStream output = System.out;
    private InputStream input = System.in;
    BufferedWriter errorFile;
    private boolean clearContextUnitsBetweenRuns = true;
    int batchesPerTrial = 1;
    boolean batchUpdate = true;

    /**
     * Creates the console on the given streams and starts its thread at once.
     *
     * @param newOutput stream that receives all prompts and results
     * @param newInput  stream the commands are read from
     */
    public SRNInterface(PrintStream newOutput, InputStream newInput) throws Exception
    {
        this.output = newOutput;
        this.input = newInput;
        this.start();
    }

    /** Creates the console on {@code System.out}/{@code System.in} and starts its thread at once. */
    public SRNInterface() throws Exception
    {
        this.start();
    }

    /**
     * Console main loop: opens {@code error.txt}, prints the banner, builds the
     * network from the answers to the set-up questions and then executes
     * commands until "q"/"exit" is entered or the input ends.  The error file
     * is closed on every way out of this method.
     */
    public void run()
    {
        try
        {
            errorFile = new BufferedWriter(new FileWriter("error.txt"));
            printBanner();
            // The reader is created once; a later setInput() call does not affect it.
            BufferedReader in = new BufferedReader(new InputStreamReader(input));
            SimpleRecurrentNet net = createNetwork(in);
            commandLoop(in, net);
        }
        catch (Exception e)
        {
            e.printStackTrace();
        }
        finally
        {
            closeErrorFile();
        }
    }

    /** Prints the start-up banner. */
    private void printBanner()
    {
        output.println("------------------------------------------------------------------------------");
        output.println("| Simple Recurrent Neural Network ASCII Interface Version 1.00               |");
        output.println("| Copyright (C) 2004 Patryk Laurent (patryk@cnbc.cmu.edu)                    |");
        output.println("|                                                                            |");
        output.println("|       o--< o--< o                                                          |");
        output.println("|       v    |                                                               |");
        output.println("|       \\____/                                                               |");
        output.println("------------------------------------------------------------------------------");
    }

    /**
     * Asks for seed, hidden layer size, learning rate and momentum and builds
     * the network with a 28-unit input layer, the requested hidden layer and a
     * 28-unit output layer.  Also records where the output layer lives.
     */
    private SimpleRecurrentNet createNetwork(BufferedReader in) throws Exception
    {
        output.print("Random number seed: ");
        long seed = Long.parseLong(readRequiredLine(in).trim());
        output.print("Please enter hidden layer size: ");
        hiddenLayerSize = Integer.parseInt(readRequiredLine(in).trim());
        output.print("Please enter learning rate: ");
        double learningRate = Double.parseDouble(readRequiredLine(in));
        output.print("Please enter momentum: ");
        double momentumTerm = Double.parseDouble(readRequiredLine(in));

        int inputEnd = VECTOR_SIZE;
        int hiddenEnd = inputEnd + hiddenLayerSize;
        int outputEnd = hiddenEnd + VECTOR_SIZE;
        int [] layerBreaks = {inputEnd, hiddenEnd, outputEnd};
        this.beginOutput = hiddenEnd;
        this.endOutput = outputEnd;

        SimpleRecurrentNet net = new SimpleRecurrentNet(seed, outputEnd, learningRate, momentumTerm, layerBreaks);
        if (batchUpdate)
        {
            net.setBatchUpdate(true);
        }
        output.println("\nASCII-ready Simple Neural Network created and ready for training.\n");
        return net;
    }

    /**
     * Reads and executes commands until the input ends.  Bad numbers and
     * unencodable text are reported to the user instead of ending the session.
     */
    private void commandLoop(BufferedReader in, SimpleRecurrentNet net) throws Exception
    {
        while (true)
        {
            output.print("> ");
            String command = in.readLine();
            if (command == null)
            {
                return; // end of input: nothing more to execute
            }
            try
            {
                net = executeCommand(command, in, net);
            }
            catch (NumberFormatException e)
            {
                output.println("Invalid number: " + e.getMessage());
            }
            catch (IllegalArgumentException e)
            {
                output.println(e.getMessage());
            }
            catch (EOFException e)
            {
                return; // input ended in the middle of a command
            }
        }
    }

    /**
     * Executes a single console command.
     *
     * @return the network to use from now on (differs from {@code net} only
     *         after a successful "load")
     */
    private SimpleRecurrentNet executeCommand(String s, BufferedReader in, SimpleRecurrentNet net) throws Exception
    {
        if (s.equals("learningRate") || s.equals("learn"))
        {
            output.println("New learning rate:");
            net.setLearningRate(Double.parseDouble(readRequiredLine(in)));
        }
        else if (s.equals("clearContext"))
        {
            output.println("Clear context units between runs? (currently="+clearContextUnitsBetweenRuns+"):");
            String answer = readRequiredLine(in);
            clearContextUnitsBetweenRuns =
                answer.equalsIgnoreCase("true") || answer.equalsIgnoreCase("yes");
            output.println("clearContextUnitsBetweenRuns = "+clearContextUnitsBetweenRuns);
        }
        else if (s.equals("momentum"))
        {
            output.println("New momentum:");
            net.setMomentumTerm(Double.parseDouble(readRequiredLine(in)));
        }
        else if (s.equals("batchesPerTrial"))
        {
            output.println("New batchesPerTrial:");
            batchesPerTrial = Integer.parseInt(readRequiredLine(in).trim());
        }
        else if (s.equals("add"))
        {
            output.println("Please enter a prompt sentence, then enter, then a response sentence.");
            output.print("Prompt: ");
            String prompt = readRequiredLine(in);
            output.print("Response: ");
            String response = readRequiredLine(in);
            // Reject unencodable text now; otherwise every later "train" would fail.
            validateMessage(prompt);
            validateMessage(response);
            inputStrings.add(prompt);
            outputStrings.add(response);
        }
        else if (s.equals("list"))
        {
            for (int i = 0; i < inputStrings.size(); i++)
            {
                output.println(i + " " + inputStrings.get(i) + " -> " + outputStrings.get(i));
            }
        }
        else if (s.equals("train"))
        {
            output.print("How many trials? ");
            int trials = Integer.parseInt(readRequiredLine(in).trim());
            output.println(System.currentTimeMillis() + " Training network...");
            trainNetwork(net, trials);
            output.println(System.currentTimeMillis() + " Training done.");
        }
        else if (s.equals("connectivity"))
        {
            output.println(net.getConnectivity());
        }
        else if (s.equals("graph"))
        {
            output.print("Threshold?");
            double threshold = Double.parseDouble(readRequiredLine(in));
            output.print("Linescale? ");
            double linescale = Double.parseDouble(readRequiredLine(in));
            output.println(net.getGraphviz(threshold, linescale));
        }
        else if (s.equals("load"))
        {
            output.print("Filename? ");
            net = loadNetwork(readRequiredLine(in), net);
        }
        else if (s.equals("save"))
        {
            output.print("Filename? ");
            saveNetwork(readRequiredLine(in), net);
        }
        else if (s.equals("test"))
        {
            output.println("Testing network.  Enter blank line to end test.");
            while (true)
            {
                output.print("net> ");
                String inputString = in.readLine();
                if (inputString == null || inputString.equals(""))
                {
                    break;
                }
                try
                {
                    testNetwork(net, inputString);
                }
                catch (IllegalArgumentException e)
                {
                    // Stay in test mode; only this one line could not be encoded.
                    output.println(e.getMessage());
                }
            }
        }
        else if (s.equals("help"))
        {
            output.println("Commands: add list train help connectivity graph load save test learningRate learn momentum batchesPerTrial clearContext set history q exit");
        }
        else if (s.equals("set") || s.equals("history"))
        {
            output.println("Clear context between runs? clearContextUnitsBetweenRuns = " + clearContextUnitsBetweenRuns);
            output.println("Current Learning rate: " + net.getLearningRate());
            output.println("Current Momentum Term: " + net.getMomentumTerm());
            output.println("Total number of trials trained: " + totalTrials);
            output.println("Stimulus batch exposure per trial: " + batchesPerTrial);
            output.println("History:");
            output.println(historyLog);
        }
        else if (s.equals("q") || s.equals("exit"))
        {
            // System.exit skips finally blocks, so close the error file explicitly first.
            closeErrorFile();
            System.exit(0);
        }
        return net;
    }

    /**
     * Deserializes a network from {@code filename}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files
     * from a trusted source.  The output-layer range and hidden layer size
     * remembered by this console are not updated, so a loaded net is assumed
     * to have the same layout as the one created at start-up - confirm.
     *
     * @return the loaded network, or {@code current} if loading failed
     */
    private SimpleRecurrentNet loadNetwork(String filename, SimpleRecurrentNet current)
    {
        InputStream fileIn = null;
        try
        {
            fileIn = new FileInputStream(filename);
            return (SimpleRecurrentNet) new ObjectInputStream(fileIn).readObject();
        }
        catch (Exception e)
        {
            output.println("Could not load: "+e.getClass()+" " +e.getMessage());
            return current;
        }
        finally
        {
            closeQuietly(fileIn);
        }
    }

    /** Serializes {@code net} to {@code filename}, reporting (not throwing) any failure. */
    private void saveNetwork(String filename, SimpleRecurrentNet net)
    {
        OutputStream fileOut = null;
        try
        {
            fileOut = new FileOutputStream(filename);
            ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
            objectOut.writeObject(net);
            // close() flushes the object stream's internal buffer; without it
            // the tail of the serialized data might never reach the file.
            objectOut.close();
            fileOut = null;
        }
        catch (Exception e)
        {
            output.println("Could not save: "+e.getClass()+" " +e.getMessage());
        }
        finally
        {
            if (fileOut != null)
            {
                try
                {
                    fileOut.close();
                }
                catch (IOException ignored)
                {
                    // the save failure has already been reported above
                }
            }
        }
    }

    /** Closes {@code stream} if it is open, ignoring any failure to do so. */
    private static void closeQuietly(InputStream stream)
    {
        if (stream == null)
        {
            return;
        }
        try
        {
            stream.close();
        }
        catch (IOException ignored)
        {
            // nothing useful can be done about a failed close of a read-only stream
        }
    }

    /** Closes the error log if it was opened; safe to call more than once. */
    private void closeErrorFile()
    {
        if (errorFile == null)
        {
            return;
        }
        try
        {
            errorFile.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Reads one line that the current command cannot do without.
     *
     * @throws EOFException if the input has ended
     */
    private static String readRequiredLine(BufferedReader in) throws IOException
    {
        String line = in.readLine();
        if (line == null)
        {
            throw new EOFException("Input ended while a reply was expected.");
        }
        return line;
    }

    /**
     * Copies the first {@code n} elements of {@code a} over the first
     * {@code n} elements of {@code b}.
     *
     * @param a source list; must hold at least {@code n} elements
     * @param b destination list; must already hold at least {@code n} elements
     * @param n number of leading elements to copy
     * @throws IndexOutOfBoundsException if either list is shorter than {@code n}
     */
    public void copyFromTo(ArrayList a, ArrayList b, int n)
    {
        for (int i = 0; i < n; i++)
        {
            b.set(i, a.get(i));
        }
    }

    /**
     * Trains {@code net} for {@code trials} passes over all stored
     * prompt/response pairs.  For a pair such as "hello" / "hi there" the
     * input and target sequences are laid out as
     * <pre>
     *   input : hello.........
     *   target: .....hi there.
     * </pre>
     * (with the leading target filler replaced by the prompt itself when
     * {@link #TRAIN_IDENTITY_DURING_QUESTION} is set).  Every 1000th trial the
     * error is logged and the net is tested on all prompts.
     *
     * @throws IllegalArgumentException if a stored sentence contains a
     *         character that cannot be encoded
     */
    private void trainNetwork(SimpleRecurrentNet net, int trials)
    {
        int pairCount = inputStrings.size();
        ArrayList inputStringVectors = new ArrayList();
        ArrayList outputStringVectors = new ArrayList();
        int [] stepCounts = new int[pairCount];

        for (int i = 0; i < pairCount; i++)
        {
            String inputString  = (String) inputStrings.get(i);
            String outputString = (String) outputStrings.get(i);
            int exchangeLength  = inputString.length() + outputString.length();
            stepCounts[i] = exchangeLength + 1; // +1 for the terminating '.'
            ArrayList inputVectorList  = convertStringToVectors(inputString + ".", exchangeLength + 1, QUESTION);
            ArrayList outputVectorList = convertStringToVectors(outputString + ".", exchangeLength + 1, ANSWER);
            if (TRAIN_IDENTITY_DURING_QUESTION)
            {
                copyFromTo(inputVectorList, outputVectorList, inputString.length());
            }
            inputStringVectors.add(inputVectorList);
            outputStringVectors.add(outputVectorList);
        }

        // Logged only after the data was encoded successfully, so a rejected
        // training set does not show up as a training run in the history.
        historyLog.add("Net trained for " + trials + " trials on batches x " + batchesPerTrial + " " + inputStrings + "->" + outputStrings+".\n  LearningRate = " + net.getLearningRate() + ", momentum = " + net.getMomentumTerm() + " hiddenLayer = "+hiddenLayerSize+"\n");

        double [] previousSSE = new double[pairCount];
        for (int trial = 0; trial < trials; trial++)
        {
            totalTrials++;
            boolean reportingTrial = (trial % REPORT_INTERVAL == 0);
            for (int i = 0; i < pairCount; i++)
            {
                List inputs  = (List) inputStringVectors.get(i);
                List targets = (List) outputStringVectors.get(i);
                if (clearContextUnitsBetweenRuns)
                {
                    net.clearContextUnits();
                }
                net.clearSSEMeasure();
                for (int withinBatch = 0; withinBatch < batchesPerTrial; withinBatch++)
                {
                    for (int j = 0; j < stepCounts[i]; j++)
                    {
                        net.feedforward((double []) inputs.get(j));
                        net.backpropogate((double []) targets.get(j));
                        net.copyHiddenExcitationToContextUnits();
                    }
                    if (batchUpdate)
                    {
                        net.batchUpdateWeights();
                    }
                    // Report once per pair, after the first batch, never on trial 0.
                    if (reportingTrial && trial > 0 && withinBatch == 0)
                    {
                        reportError(net, trial, i, previousSSE);
                    }
                }
            }
            if (reportingTrial)
            {
                for (int testIndex = 0; testIndex < inputStrings.size(); testIndex++)
                {
                    output.print(inputStrings.get(testIndex) + "\t-> ");
                    testNetwork(net, (String) inputStrings.get(testIndex));
                }
            }
        }
        output.println("Total times through training set: " + totalTrials);
    }

    /**
     * Writes the current sum-squared error of pair {@code pairIndex}, and its
     * change since the previous report, to the error file and the console.
     */
    private void reportError(SimpleRecurrentNet net, int trial, int pairIndex, double [] previousSSE)
    {
        double sse = net.getSumSquaredError();
        String line = "Training trial " + trial + " SumSq Error: " + sse
            + " diff = "+(sse - previousSSE[pairIndex]);
        previousSSE[pairIndex] = sse;
        try
        {
            errorFile.write(line + "\n");
            errorFile.flush();
        }
        catch (Exception e)
        {
            // A failing log must not abort training; the console still shows the line.
            e.printStackTrace();
        }
        output.println(line);
    }

    /**
     * Feeds {@code inputString} (followed by end-of-message filler up to 41
     * time steps) through {@code net} and prints the winning output symbol of
     * every step on one line.
     *
     * @throws IllegalArgumentException if {@code inputString} contains a
     *         character that cannot be encoded
     */
    private void testNetwork(SimpleRecurrentNet net, String inputString)
    {
        int steps = MAX_TEST_EXCHANGE_LENGTH + 1;
        List inputVectors = convertStringToVectors(inputString + ".", steps, QUESTION);
        if (clearContextUnitsBetweenRuns)
        {
            net.clearContextUnits();
        }
        for (int j = 0; j < steps; j++)
        {
            net.feedforward((double []) inputVectors.get(j));
            output.print(convertNetOutputToAlpha(net, beginOutput, endOutput));
            net.copyHiddenExcitationToContextUnits();
        }
        output.println("");
    }

    /**
     * Translates the most active unit in [{@code start}, {@code cease}) into
     * its symbol: a letter, " " for the word separator or "." for end of
     * message.  On ties the unit with the lowest index wins.
     */
    private String convertNetOutputToAlpha(SimpleRecurrentNet net, int start, int cease)
    {
        int maxIndex = start;
        double maxOutput = net.getActivity(maxIndex);
        for (int i = start; i < cease; i++)
        {
            if (net.getActivity(i) > maxOutput)
            {
                maxIndex = i;
                maxOutput = net.getActivity(i);
            }
        }
        int symbol = maxIndex - start;
        if (symbol == WORD_SEPARATION_INDEX)
        {
            return " ";
        }
        if (symbol == END_MESSAGE_INDEX)
        {
            return ".";
        }
        return String.valueOf((char) ('a' + symbol));
    }

    /** Debug helper: prints every vector of the list on its own line. */
    private void showVectorsInList(List doubleVectors)
    {
        for (int i = 0; i < doubleVectors.size(); i++)
        {
            showVector((double []) doubleVectors.get(i));
            output.println("");
        }
    }

    /** Debug helper: prints the elements of {@code a} separated by spaces. */
    private void showVector(double [] a)
    {
        for (int i = 0; i < a.length; i++)
        {
            output.print(a[i] + " ");
        }
    }

    /**
     * Encodes {@code s} as a list of one-hot vectors, padded with
     * end-of-message vectors to {@code exchangeLength} entries.  A
     * {@code QUESTION} is padded at the end, an {@code ANSWER} at the front.
     * If {@code s} is already longer than {@code exchangeLength} no padding
     * is added and the result has {@code s.length()} entries.
     *
     * @throws IllegalArgumentException if {@code s} contains a character
     *         other than a letter a-z (either case), a space or '.'
     */
    private ArrayList convertStringToVectors(String s, int exchangeLength, int questionOrAnswer)
    {
        ArrayList result = new ArrayList();
        int fillerRequired = exchangeLength - s.length();

        if (questionOrAnswer == ANSWER)
        {
            appendFiller(result, fillerRequired);
        }
        for (int i = 0; i < s.length(); i++)
        {
            result.add(vectorForCharacter(s.charAt(i)));
        }
        if (questionOrAnswer == QUESTION)
        {
            appendFiller(result, fillerRequired);
        }
        return result;
    }

    /** Appends {@code count} end-of-message vectors (none if {@code count <= 0}). */
    private void appendFiller(ArrayList target, int count)
    {
        for (int i = 0; i < count; i++)
        {
            target.add(END_MESSAGE_VECTOR);
        }
    }

    /**
     * Returns the one-hot vector for {@code c}.  Upper-case ASCII letters are
     * encoded like their lower-case counterparts.
     *
     * @throws IllegalArgumentException if {@code c} has no encoding
     */
    private double [] vectorForCharacter(char c)
    {
        if (c == '.')
        {
            return this.END_MESSAGE_VECTOR;
        }
        if (c == ' ')
        {
            return this.WORD_SEPARATION_VECTOR;
        }
        char letter = c;
        if (letter >= 'A' && letter <= 'Z')
        {
            letter = (char) (letter - 'A' + 'a');
        }
        if (letter < 'a' || letter > 'z')
        {
            throw new IllegalArgumentException(
                "Unsupported character '" + c + "': only the letters a-z, spaces and '.' can be encoded.");
        }
        return ALPHABETIC_VECTOR[letter - 'a'];
    }

    /**
     * Checks that every character of {@code s} can be encoded.
     *
     * @throws IllegalArgumentException naming the first character that cannot
     */
    private void validateMessage(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            vectorForCharacter(s.charAt(i));
        }
    }

    /** Returns a new 28-element vector that is 1 at {@code index} and 0 elsewhere. */
    private static double [] oneHot(int index)
    {
        double [] vector = new double[VECTOR_SIZE];
        vector[index] = 1;
        return vector;
    }

    /** Builds the 26 one-hot letter vectors; row i has its 1 at position i. */
    private static double [] [] alphabetVectors()
    {
        double [] [] vectors = new double[ALPHABET_SIZE][];
        for (int letter = 0; letter < ALPHABET_SIZE; letter++)
        {
            vectors[letter] = oneHot(letter);
        }
        return vectors;
    }

    /** Starts the console on standard input and output. */
    public static void main(String args[]) throws Exception
    {
        new SRNInterface(System.out, System.in);
    }

    /**
     * Redirects all further console output to {@code newOutput}.
     */
    public void setOutput(PrintStream newOutput)
    {
        output = newOutput;
    }

    /**
     * Replaces the input stream.  {@link #run()} wraps the input stream in a
     * reader once when it begins, so a call made after that point does not
     * change where commands are read from.
     */
    public void setInput(InputStream newInput)
    {
        input = newInput;
    }

}
