package net.pakl.rl;


import java.util.*;
import net.pakl.neuralnet.*;

/**
 * Value function backed by a {@link Perceptron} function approximator whose
 * weights are adjusted with a residual-style update.
 *
 * <p>For a transition from state {@code x} to state {@code x'} the weight change
 * computed by {@link #setValue(State, State, double, double)} is
 *
 * <pre>
 *   dw = -learningRate * (forNetwork(newValue) - forNetwork(V(x)))
 *        * (factor * dV(x')/dw - dV(x)/dw)
 * </pre>
 *
 * where {@code factor} is {@code residualWeighting * discountFactor} when
 * {@code RESIDUAL_ALGORITHM} is set and plain {@code discountFactor} otherwise.
 *
 * <p>The network is created lazily, on the first call to {@code setValue}, from the
 * length of the first state's vector representation. Values are rescaled between
 * the external range and the network's range by {@link #forNetwork(double)} and
 * {@link #fromNetwork(double)}.
 */
public class ValueFunctionResidualAlgorithmPerceptron extends ValueFunctionPerceptron implements ValueFunction
{
    // Our Value functions return fixed values for the terminal states, by definition.
    World world = null;
    double valueOfTerminalStates = 0;

    // The Residual Algorithm differentially weights the two partial derivative terms.
    double residualWeighting = 0.4;

    // Neural Network (Function Approximator); null until the first setValue call.
    Perceptron net = null;
    public int hiddenUnits = 5;
    public double learningRate = 0.01;
    // Index of the neuron whose activity is read as the value estimate.
    int outputNeuron = 0;
    // Scale used by forNetwork/fromNetwork: external values in [-m, m] map to [0, 1].
    public double multiplyAndAdd = 10d;
    // Number of rows of the weight matrix reported by the network.
    protected int numPresynaptics = 0;
    double [][] totalWeightChangeForEpoch = null;       // in case of epoch-wise (batch) mode
    long randomSeed = 12345;

    boolean RESIDUAL_ALGORITHM = true;
    // true: apply each weight change immediately; false: accumulate until
    // storeWeightChangesIfNonIncremental() is called.
    public boolean INCREMENTAL_UPDATE = true;
    public boolean CROSS_PRODUCT = false;       // not read anywhere in this class

    public boolean USE_REAL_VECTORS = true;     // not read anywhere in this class

    protected String name = "valuefunction";

    /**
     * Creates a value function that consults {@code w} to recognise terminal states.
     *
     * @param w world used for terminal-state lookups
     */
    public ValueFunctionResidualAlgorithmPerceptron(World w)
    {
        this.world = w;
    }

    /**
     * Creates a value function with no world attached. Until {@code world} is
     * assigned, no state is treated as terminal.
     */
    public ValueFunctionResidualAlgorithmPerceptron()
    {
    }

    /** Sets the descriptive name of this value function. */
    public void setName(String name)
    {
        this.name = name;
    }

    /** Returns the descriptive name of this value function. */
    public String getName()
    {
        return name;
    }

    /**
     * Not supported by this approximator.
     *
     * @throws UnsupportedOperationException always
     */
    public void setValueOfOutOfBoundsStates(double newValueOfOutOfBoundsStates)
    {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * Not supported by this approximator.
     *
     * @throws UnsupportedOperationException always
     */
    public void setAllowExpansionOfStateBounds(boolean trueOrFalse)
    {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * In batch mode ({@code INCREMENTAL_UPDATE == false}) adds the weight changes
     * accumulated since the last call to the network's weights and clears the
     * accumulator. Does nothing in incremental mode, or when the network has not
     * been created yet (nothing can have been accumulated in that case).
     */
    public void storeWeightChangesIfNonIncremental()
    {
        if (INCREMENTAL_UPDATE)
        {
            // Every sample already updated the weights directly.
            return;
        }
        if (net == null)
        {
            // No sample has been presented, so there is nothing to apply.
            return;
        }

        System.out.println("Storing batch weight changes.");
        net.setWeights(vadd(net.getWeights(), totalWeightChangeForEpoch));
        zeroOutTotalWeightChangeForEpoch();
    }

    /**
     * Replaces the batch accumulator with a fresh all-zero
     * {@code numPresynaptics x numPresynaptics} matrix.
     * NOTE(review): this assumes the network's weight matrix is square - confirm
     * against Perceptron.getWeights().
     */
    protected void zeroOutTotalWeightChangeForEpoch()
    {
        totalWeightChangeForEpoch = new double[numPresynaptics][numPresynaptics];
    }

    /** Returns the underlying network, or {@code null} if it has not been created yet. */
    public Perceptron getNetwork()
    {
        return this.net;
    }

    /**
     * Builds the network for states shaped like {@code thisStateObject}: one input
     * per element of its vector representation, {@code hiddenUnits} hidden neurons
     * and a single output neuron. Also sizes the batch accumulator and logs the
     * chosen configuration to standard output.
     *
     * @param thisStateObject example state whose vector length fixes the input size
     */
    protected void initializeNetworkForStatesOfType(HasVectorRepresentation thisStateObject)
    {
        System.out.println("Initialize network for ValueFunctionResidualAlgorithmPerceptron");

        int inputSize = thisStateObject.doubleRepresentation().length;
        int neuronCount = inputSize + hiddenUnits + 1;
        // Cumulative neuron indices at which the input, hidden and output layers end.
        int [] layerBreaks = {inputSize, inputSize + hiddenUnits, neuronCount};
        System.out.println("Input size v2: " + inputSize);
        System.out.println("Hidden units:" + hiddenUnits);

        // NOTE(review): the meaning of the literal 0 argument is defined by Perceptron.
        net = new Perceptron(randomSeed, neuronCount, learningRate, 0, layerBreaks);
        System.out.println("We specified " + neuronCount + " neurons and learn rate is " + learningRate);

        numPresynaptics = net.getWeights().length;
        System.out.println("Perceptron really has (because of BIAS neurons) presynaptics with weights " + numPresynaptics);
        outputNeuron = inputSize + hiddenUnits;
        System.out.println("outputNeuron is neuron #" + outputNeuron);
        totalWeightChangeForEpoch = new double[numPresynaptics][numPresynaptics];

        System.out.println("residualWeighting = " + residualWeighting);
        System.out.println("multiplyAndAdd = " + multiplyAndAdd);
    }

    /**
     * Moves the estimate for {@code thisStateObject} towards {@code newValue} using
     * the update described in the class comment. The network is created on first use.
     *
     * @param thisStateObject state whose value is being trained
     * @param nextStateObject successor state; its output gradient enters the update
     * @param newValue        target value for {@code thisStateObject}
     * @param discountFactor  discount applied to the successor's gradient term
     */
    protected void setValue(HasVectorRepresentation thisStateObject, HasVectorRepresentation nextStateObject, double newValue, double discountFactor)
    {
        if (net == null)
        {
            initializeNetworkForStatesOfType(thisStateObject);
        }

        double [] thisState = thisStateObject.doubleRepresentation();
        double [] nextState = nextStateObject.doubleRepresentation();

        // Error is measured in the network's own (rescaled) output range.
        double error = forNetwork(newValue) - forNetwork(getValue(thisStateObject));
        double multiplyBy = -1.0d * learningRate * error;

        double [][] currentWeights = net.getWeights();

        // Gradients must be read right after the matching feedforward pass.
        net.feedforward(thisState);
        double [][] partialDerivVx = net.getOutputDerivativeAgainstWeights();
        net.feedforward(nextState);
        double [][] partialDerivVx2 = net.getOutputDerivativeAgainstWeights();

        double successorFactor = RESIDUAL_ALGORITHM ? residualWeighting * discountFactor : discountFactor;
        double [][] weightChange = vmul(multiplyBy, vsub(vmul(successorFactor, partialDerivVx2), partialDerivVx));

        // To get incremental to work, do you need to randomly visit all the states? (not in sequence)
        if (INCREMENTAL_UPDATE)
        {
            net.setWeights(vadd(currentWeights, weightChange));
        }
        else
        {
            totalWeightChangeForEpoch = vadd(totalWeightChangeForEpoch, weightChange);
        }
    }

    /**
     * Public entry point for training; casts both states to
     * {@link HasVectorRepresentation} and delegates to the protected overload.
     */
    public void setValue(State thisState, State nextState, double newValue, double discountFactor)
    {
        setValue((HasVectorRepresentation) thisState, (HasVectorRepresentation) nextState, newValue, discountFactor);
    }

    /**
     * Not supported: this approximator needs the successor state as well.
     *
     * @throws UnsupportedOperationException always
     */
    public void setValue(State state, double newValue)
    {
        throw new UnsupportedOperationException("Not implemented in approximator -- pass in next state too.");
    }

    /** Returns the current value estimate for {@code state}. */
    public double getValue(State state)
    {
        return getValue((HasVectorRepresentation) state);
    }

    /**
     * Returns the value estimate for {@code state}: the fixed terminal value when
     * the attached world reports the state as terminal, {@code 0} while the network
     * does not exist yet, and otherwise the rescaled activity of the output neuron.
     *
     * @param state state to evaluate; must also be a {@link State} when a world is attached
     * @return the estimated value
     */
    protected double getValue(HasVectorRepresentation state)
    {
        // Without a world there is no way to recognise terminal states, so none are assumed.
        if (world != null && world.isTerminalState((State) state))
        {
            return valueOfTerminalStates;
        }
        if (net == null)
        {
            return 0;
        }

        net.feedforward(state.doubleRepresentation());
        return fromNetwork(net.getActivity(outputNeuron));
    }

    /** Returns the sum of all elements of {@code v}. */
    protected double sum(double []v)
    {
        double result = 0d;
        for (int i = 0; i < v.length; i++)
        {
            result = result + v[i];
        }
        return result;
    }

    /**
     * Returns the sum of all elements of {@code v}, visiting rows in order.
     * Rows may have different lengths; an empty matrix sums to {@code 0}.
     */
    protected double sum(double [][]v)
    {
        double result = 0d;
        for (int i = 0; i < v.length; i++)
        {
            for (int j = 0; j < v[i].length; j++)
            {
                result = result + v[i][j];
            }
        }
        return result;
    }

    /** Returns the element-wise difference {@code v1 - v2}; the result has {@code v1}'s length. */
    protected double [] vsub(double [] v1, double [] v2)
    {
        double [] result = new double[v1.length];
        for (int i = 0; i < v1.length; i++)
        {
            result[i] = v1[i] - v2[i];
        }
        return result;
    }

    /**
     * Returns the element-wise difference {@code v1 - v2}. The result has the shape
     * of {@code v1}, row by row, so empty and jagged matrices are handled.
     */
    protected double [][] vsub(double [][] v1, double [][] v2)
    {
        double [][] result = new double[v1.length][];
        for (int i = 0; i < v1.length; i++)
        {
            result[i] = vsub(v1[i], v2[i]);
        }
        return result;
    }

    /** Returns the element-wise sum {@code v1 + v2}; the result has {@code v1}'s length. */
    protected double [] vadd(double [] v1, double [] v2)
    {
        double [] result = new double[v1.length];
        for (int i = 0; i < v1.length; i++)
        {
            result[i] = v1[i] + v2[i];
        }
        return result;
    }

    /** Returns a new array holding each element of {@code x} widened to {@code double}. */
    protected double [] intsToDoubles(int [] x)
    {
        double [] result = new double[x.length];
        for (int i = 0; i < x.length; i++)
        {
            result[i] = x[i];
        }
        return result;
    }

    /**
     * Returns the element-wise sum {@code v1 + v2}. The result has the shape of
     * {@code v1}, row by row, so empty and jagged matrices are handled.
     */
    protected double [][] vadd(double [][] v1, double [][] v2)
    {
        double [][] result = new double[v1.length][];
        for (int i = 0; i < v1.length; i++)
        {
            result[i] = vadd(v1[i], v2[i]);
        }
        return result;
    }

    /** Returns a new vector with every element of {@code vector} multiplied by {@code scalar}. */
    protected double [] vmul(double scalar, double [] vector)
    {
        double [] result = new double[vector.length];
        for (int i = 0; i < vector.length; i++)
        {
            result[i] = scalar * vector[i];
        }
        return result;
    }

    /**
     * Returns a new matrix with every element of {@code vector} multiplied by
     * {@code scalar}. The result keeps each row's own length.
     */
    protected double [][] vmul(double scalar, double [][] vector)
    {
        double [][] result = new double[vector.length][];
        for (int i = 0; i < vector.length; i++)
        {
            result[i] = vmul(scalar, vector[i]);
        }
        return result;
    }

    /**
     * Returns the dot product of {@code u} and {@code v}.
     *
     * @throws RuntimeException if the vectors differ in length
     */
    protected double dotProduct(double [] u,  double [] v)
    {
        if (u.length != v.length)
        {
            throw new RuntimeException("Vectors must be of same length for dot product.");
        }
        double result = 0;
        for (int i = 0; i < u.length; i++)
        {
            result = result + u[i] * v[i];
        }
        return result;
    }

    /** Sets the weighting applied (together with the discount) to the successor gradient. */
    public void setResidualWeighting(double newValue)
    {
        this.residualWeighting = newValue;
    }

    /**
     * Sets the number of hidden neurons used when the network is created.
     *
     * @throws IllegalStateException if the network already exists
     */
    public void setHiddenUnits(int newValue)
    {
        if (net != null)
        {
            throw new IllegalStateException("Sorry -- you cannot set this AFTER net already initialized.");
        }
        this.hiddenUnits = newValue;
    }

    /**
     * Sets the learning rate used in this class's weight update. An already-created
     * network object is not touched by this call.
     */
    public void setLearningRate(double newValue)
    {
        this.learningRate = newValue;
    }

    /**
     * Sets the magnitude {@code m} of the external value range {@code [-m, m]} used
     * by {@link #forNetwork(double)} and {@link #fromNetwork(double)}.
     *
     * @throws IllegalStateException if the network already exists
     */
    public void setMaxMinValue(double newValue)
    {
        if (net != null)
        {
            throw new IllegalStateException("Sorry -- you cannot set this AFTER net already initialized.");
        }
        this.multiplyAndAdd = newValue;
    }

    /** Returns a fixed placeholder description of this value function. */
    public String toText()
    {
        return "[ValueFunctionResidualAlgorithm - toText()]";
    }

    /** Sets the value returned for every state the world reports as terminal. */
    public void setValueOfTerminalStates(double newValue)
    {
        valueOfTerminalStates = newValue;
    }

    /**
     * Rescales an external value for the network: {@code -multiplyAndAdd} maps to
     * {@code 0} and {@code +multiplyAndAdd} maps to {@code 1}.
     */
    public double forNetwork(double value)
    {
        return (multiplyAndAdd + value) / (2.0d * multiplyAndAdd);
    }

    /** Inverse of {@link #forNetwork(double)}: maps {@code 0..1} back to {@code -m..m}. */
    public double fromNetwork(double value)
    {
        return (2.0d * multiplyAndAdd) * value - multiplyAndAdd;
    }

    /** Sets the seed passed to the network when it is created. */
    public void setRandomSeed(long x)
    {
        this.randomSeed = x;
    }
}
