package net.pakl.rl;


import java.util.*;

/**
 * Value function backed by a linear function approximator: the value of a state is
 * {@code VectorTools.vdot(features, weights)}, where {@code features} is derived from the
 * state's {@link HasVectorRepresentation#binaryRepresentation()}.
 *
 * <p>Weights are trained either with the residual-algorithm update (when
 * {@link #RESIDUAL_ALGORITHM} is {@code true}) or with the plain Widrow-Hoff / LMS rule.
 * Updates are applied immediately when {@link #INCREMENTAL_UPDATE} is {@code true};
 * otherwise they are accumulated until {@link #storeWeightChangesIfNonIncremental()} is called.
 *
 * <p>States for which {@code world.isTerminalState} returns {@code true} always report
 * {@code valueOfTerminalStates}. If no world has been supplied, no state is treated as terminal.
 */
public class ValueFunctionResidualAlgorithmLinear implements ValueFunction
{
    static final long serialVersionUID = 5035667110689014773L;

    /** Source of the initial random weights; fixed seed keeps runs reproducible. */
    java.util.Random randomNumberGenerator = new java.util.Random(123);

    /** World used to decide which states are terminal; may be null until {@link #setWorld} is called. */
    World world = null;

    /** Fixed value returned by {@code getValue} for every terminal state. */
    double valueOfTerminalStates = 0;

    /** The residual algorithm differentially weights the two partial-derivative terms by this factor. */
    double residualWeighting = 0.4;

    // Function approximator ------------------------------------------------------------------------

    /** Weight vector; allocated lazily on the first update, once the feature-vector length is known. */
    double [] weights = null;
    private boolean NORMALIZE_INPUT_VECTORS = false;
    private boolean NORMALIZE_INITIAL_WEIGHTS = false;
    private boolean NORMALIZE_WEIGHTS = false;

    /** When true, feature vectors are expanded with {@code VectorTools.crossProduct}. */
    public boolean CROSS_PRODUCT = false;
    /** When true (and CROSS_PRODUCT is true), the cross product expansion is applied a second time. */
    private boolean DOUBLE_CROSS_PRODUCT = false;
    /** When true, feature vectors are passed through {@code VectorTools.binaryComplement}. */
    public boolean BINARY_COMPLEMENT = false;
    /** Not consulted by this class other than being reported by displayStats(). */
    public boolean HAND_CRAFTED_VECTORS = false;

    /** Step size applied to every weight update. */
    public double learningRate = 0.01;
    int outputNeuron = 0;
    /** Set through {@link #setMaxMinValue}; within this class it is only reported by displayStats(). */
    public double multiplyAndAdd = 10d;
    private int numPresynaptics = 0;

    // Bias-unit support is not implemented: these two fields are only reported by displayStats().
    private boolean USE_BIAS_UNIT = false;
    private double BIAS_UNIT_ACTIVITY = 1.0;

    /** Selects the residual-algorithm update (true) or the plain LMS update (false). */
    public boolean RESIDUAL_ALGORITHM = true;

    /** True: apply each weight change immediately. False: accumulate changes (epoch-wise / batch mode). */
    public boolean INCREMENTAL_UPDATE = true;
    /** Accumulated weight change, used only in epoch-wise (batch) mode. */
    double [] totalWeightChange = null;
    // ----------------------------------------------------------------------------------------------

    public boolean SCALE_OUTPUT_VALUE = false;

    private String name = "valuefunction";

    /** Length of the weight vector; 0 until the network has been initialized. */
    private int numWeights = 0;

    /** Sum of absolute weight changes applied in incremental mode since the last reset. */
    double totalIncrementalWeightChange = 0;

    /**
     * Creates a value function that consults {@code w} to identify terminal states.
     *
     * @param w the world used for terminal-state checks
     */
    public ValueFunctionResidualAlgorithmLinear(World w)
    {
        System.out.println("ValueFunctionResidualAlgorithmLinear initialized.");
        this.world = w;
    }

    /**
     * Creates a value function without a world. Until {@link #setWorld} is called,
     * no state is treated as terminal.
     */
    public ValueFunctionResidualAlgorithmLinear()
    {
        System.out.println("ValueFunctionResidualAlgorithmLinear -- no world passed in! No terminal states?");
    }

    /**
     * Sets the world used for terminal-state checks.
     *
     * @param w the new world
     */
    public void setWorld(World w)
    {
        this.world = w;
    }

    /**
     * Sets the name of this value function.
     *
     * @param name the new name
     */
    public void setName(String name)
    {
        this.name = name;
    }

    /**
     * @return the name of this value function
     */
    public String getName()
    {
        return name;
    }

    /**
     * Not supported by this approximator.
     *
     * @param newValueOfOutOfBoundsStates ignored
     * @throws UnsupportedOperationException always
     */
    public void setValueOfOutOfBoundsStates(double newValueOfOutOfBoundsStates)
    {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * Not supported by this approximator.
     *
     * @param trueOrFalse ignored
     * @throws UnsupportedOperationException always
     */
    public void setAllowExpansionOfStateBounds(boolean trueOrFalse)
    {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * Sets the weighting applied to the next-state term of the residual update.
     *
     * @param newValue the new residual weighting
     */
    public void setResidualWeighting(double newValue)
    {
        this.residualWeighting = newValue;
    }

    /**
     * Ignored: a linear approximator has no hidden units. A notice is written to stderr.
     *
     * @param newValue ignored
     */
    public void setHiddenUnits(int newValue)
    {
        System.err.println("There are no hidden units in this functional approximator -- set command ignored.");
    }

    /**
     * Sets the step size used for weight updates.
     *
     * @param newValue the new learning rate
     */
    public void setLearningRate(double newValue)
    {
       this.learningRate = newValue;
    }

    /**
     * Stores {@code newValue} in {@link #multiplyAndAdd}.
     *
     * @param newValue the new value
     */
    public void setMaxMinValue(double newValue)
    {
       this.multiplyAndAdd = newValue;
    }

    /**
     * Sets the fixed value reported for terminal states.
     *
     * @param newValue the value to report for terminal states
     */
    public void setValueOfTerminalStates(double newValue)
    {
        valueOfTerminalStates = newValue;
    }

    /**
     * @return a fixed textual tag identifying this value function type
     */
    public String toText()
    {
        return "[ValueFunctionResidualAlgorithm - toText()]";
    }

    /**
     * Returns the accumulated absolute weight change applied in incremental mode
     * and resets the accumulator to zero.
     *
     * @return the total absolute weight change since the previous call
     */
    public double getTotalIncrementalWeightChangeAndReset()
    {
        double result = totalIncrementalWeightChange;
        totalIncrementalWeightChange = 0;
        return result;
    }

    /**
     * Returns the current value estimate for {@code state}.
     *
     * @param state the state to evaluate; must also implement {@link HasVectorRepresentation}
     * @return the terminal-state value for terminal states, 0 before the first update,
     *         otherwise the approximator output
     */
    public double getValue(State state)
    {
        return getValue((HasVectorRepresentation) state);
    }

    /**
     * Moves the value estimate of {@code state} towards {@code newValue}.
     * The state is used as its own successor with a discount factor of 0.
     *
     * @param state    the state to update; must also implement {@link HasVectorRepresentation}
     * @param newValue the target value
     */
    public void setValue(State state, double newValue)
    {
        this.setValue((HasVectorRepresentation) state, (HasVectorRepresentation) state, newValue, 0);
    }

    /**
     * Moves the value estimate of {@code thisState} towards {@code newValue}, using
     * {@code nextState} and {@code discountFactor} for the residual-algorithm term.
     *
     * @param thisState      the state being updated; must also implement {@link HasVectorRepresentation}
     * @param nextState      the successor state; must also implement {@link HasVectorRepresentation}
     * @param newValue       the target value for {@code thisState}
     * @param discountFactor the discount factor applied to the successor term
     */
    public void setValue(State thisState, State nextState, double newValue, double discountFactor)
    {
        setValue((HasVectorRepresentation) thisState, (HasVectorRepresentation) nextState, newValue, discountFactor);
    }

    /**
     * In batch mode, adds the accumulated weight change to the weights and clears the
     * accumulator. Does nothing in incremental mode, or when no update has been recorded
     * yet (the weights are still unallocated).
     */
    public void storeWeightChangesIfNonIncremental()
    {
        if (INCREMENTAL_UPDATE)
        {
            return;
        }
        // Nothing has been accumulated before the first setValue call; the arrays are still null.
        if (weights == null)
        {
            return;
        }
        System.out.println("Storing batch weight changes.");
        weights = VectorTools.vadd(weights, totalWeightChange);
        if (NORMALIZE_WEIGHTS) weights = VectorTools.normalize(weights);
        totalWeightChange = new double[numWeights];
    }

    /**
     * Builds the feature vector fed to the approximator: the state's binary representation,
     * optionally cross-product expanded (once or twice), complemented, and normalized,
     * in that order. Shared by the read and update paths so both see identical features.
     */
    private int [] featureVector(HasVectorRepresentation stateObject)
    {
        int [] vector = stateObject.binaryRepresentation();

        if (CROSS_PRODUCT)
        {
            vector = VectorTools.crossProduct(vector);
            if (DOUBLE_CROSS_PRODUCT)
            {
                vector = VectorTools.crossProduct(vector);
            }
        }
        if (BINARY_COMPLEMENT)
        {
            vector = VectorTools.binaryComplement(vector);
        }
        if (NORMALIZE_INPUT_VECTORS)
        {
            vector = VectorTools.normalize(vector);
        }
        return vector;
    }

    /**
     * Computes the value of {@code state}: the fixed terminal value for terminal states,
     * 0 while the weights are unallocated, otherwise the approximator output.
     */
    private double getValue(HasVectorRepresentation state)
    {
        // A missing world means no state can be identified as terminal.
        if (world != null && world.isTerminalState((State) state)) return valueOfTerminalStates;
        if (weights == null) return 0;

        return this.feedforward(featureVector(state));
    }

    /**
     * Core update. Computes the error between {@code newValue} and the current estimate of
     * {@code thisStateObject}, derives a weight change from it, and either applies the change
     * immediately or accumulates it, depending on {@link #INCREMENTAL_UPDATE}.
     */
    private void setValue(HasVectorRepresentation thisStateObject, HasVectorRepresentation nextStateObject, double newValue, double discountFactor)
    {
        int [] thisState = featureVector(thisStateObject);
        int [] nextState = featureVector(nextStateObject);

        // The weight vector is sized from the first feature vector seen.
        if (weights == null) initializeNetworkForVectorsOfType(thisState);

        double teach = newValue - getValue(thisStateObject);

        double [] weightChange;
        if (RESIDUAL_ALGORITHM)
        {
            // NOTE(review): nextState contributes here even when it is a terminal state,
            // whose reported value is a constant -- confirm this is intended.
            double [] residual = VectorTools.vsub(VectorTools.vmul(residualWeighting * discountFactor, nextState), thisState);
            weightChange = VectorTools.vmul(-1.0 * learningRate * teach, residual);
        }
        else
        {
            weightChange = VectorTools.vmul(learningRate * teach, thisState);  // Widrow-Hoff aka Adaline aka LMS rule.
        }

        if (INCREMENTAL_UPDATE)
        {
            weights = VectorTools.vadd(weights, weightChange);
            totalIncrementalWeightChange += VectorTools.abssum(weightChange);

            if (NORMALIZE_WEIGHTS) weights = VectorTools.normalize(weights);
        }
        else
        {
            totalWeightChange = VectorTools.vadd(totalWeightChange, weightChange);
        }
    }

    /** Converts {@code vector} to doubles and initializes the network for vectors of that length. */
    private void initializeNetworkForVectorsOfType(int [] vector)
    {
        double [] temp = new double[vector.length];
        for (int i = 0; i < vector.length; i++)
        {
            temp[i] = (double) vector[i];
        }
        initializeNetworkForVectorsOfType(temp);
    }

    /**
     * Allocates the weight vector and the batch accumulator to match the length of
     * {@code thisStateObjectVector}, fills the weights with small random values
     * (optionally normalized), and prints the configuration.
     */
    private void initializeNetworkForVectorsOfType(double [] thisStateObjectVector)
    {
        numWeights = thisStateObjectVector.length;
        weights = new double[numWeights];
        totalWeightChange = new double[numWeights];

        System.out.println("Weight vector is has " + numWeights + " elements");
        for (int i = 0; i < weights.length; i++)
        {
            weights[i] = smallRandomWeight();
        }
        if (NORMALIZE_INITIAL_WEIGHTS)
        {
            System.out.println(getClass() + " Normalizing initial weights");
            weights = VectorTools.normalize(weights);
        }
        displayStats();
    }

    /** Returns a random initial weight drawn uniformly from [-0.5, 0.5). */
    private double smallRandomWeight()
    {
        return (1.0d) * (randomNumberGenerator.nextDouble() - 0.5d);
    }

    /** Approximator output for a double-valued feature vector. Currently not called within this class. */
    private double feedforward(double [] vector)
    {
        return VectorTools.vdot(vector, weights);
    }

    /** Approximator output for an int-valued feature vector. */
    private double feedforward(int [] vector)
    {
        return VectorTools.vdot(vector, weights);
    }

    /** Returns its argument unchanged. Currently not called within this class. */
    private double [] derivativeOfFunction(double [] stateVector)
    {
        return stateVector;
    }

    /** Dot product over the length of {@code a}. Currently not called within this class. */
    private static double dot(double [] a, double [] b)
    {
        double result = 0.0d;
        for (int i = 0; i < a.length; i++)
        {
            result += a[i] * b[i];
        }
        return result;
    }

    /** Intentionally empty. Currently not called within this class. */
    private void zeroOutTotalWeightChangeForEpoch() { }

    /** Prints the current configuration flags and parameters to stdout. */
    private void displayStats()
    {
        System.out.println("Residual Algorithm/Linear function approximator.\n NORMALIZE_INPUT_VECTORS="+NORMALIZE_INPUT_VECTORS+". Use_bias_unit="+USE_BIAS_UNIT);
        System.out.println(" residualWeighting = " + residualWeighting);
        System.out.println(" INCREMENTAL_UPDATE = "+INCREMENTAL_UPDATE);
        System.out.println(" multiplyAndAdd = " +multiplyAndAdd);
        System.out.println(" BINARY_COMPLEMENT = " + this.BINARY_COMPLEMENT);
        System.out.println(" CROSS_PRODUCT = " + this.CROSS_PRODUCT);
        System.out.println(" RESIDUAL_ALGORITHM = " + this.RESIDUAL_ALGORITHM);
        System.out.println(" HAND_CRAFTED_VECTORS = " + this.HAND_CRAFTED_VECTORS);
    }

}
