/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.rl;

import java.util.Random;
import net.pakl.rl.HasVectorRepresentation;
import net.pakl.rl.State;
import net.pakl.rl.ValueFunction;
import net.pakl.rl.VectorTools;
import net.pakl.rl.World;

/**
 * Linear value-function approximator: the value of a state is the dot product
 * ({@code VectorTools.vdot}) of the state's binary feature vector with a single
 * weight vector.
 *
 * <p>Weights are created lazily on the first call to {@code setValue}, sized to the
 * encoded feature vector of the first state seen. Each update is either a plain
 * delta-rule step ({@link #RESIDUAL_ALGORITHM} false) or a step along a
 * "residual" direction that mixes the next state's features (scaled by
 * {@code residualWeighting * discountFactor}) with the current state's features.
 * Updates are applied immediately ({@link #INCREMENTAL_UPDATE} true) or
 * accumulated until {@link #storeWeightChangesIfNonIncremental()} is called.
 *
 * <p>States passed through the {@link State}-typed methods are cast to
 * {@link HasVectorRepresentation}; a state that does not implement it causes a
 * {@link ClassCastException}.
 *
 * <p>Several fields below are never read by this class. They are kept so the
 * field layout of the class (which declares a {@code serialVersionUID}) is
 * unchanged.
 */
public class ValueFunctionResidualAlgorithmLinear
implements ValueFunction {
    static final long serialVersionUID = 5035667110689014773L;
    /** Source of the initial weights; fixed seed so initialisation is repeatable. */
    Random randomNumberGenerator = new Random(123L);
    /** Used only to ask whether a state is terminal; may be null (then no state is treated as terminal). */
    World world = null;
    /** Value returned by {@code getValue} for states the world reports as terminal. */
    double valueOfTerminalStates = 0.0;
    /** Multiplier (together with the discount factor) on the next state's features in the residual update. */
    double residualWeighting = 0.4;
    /** Weight vector; null until the first {@code setValue} call. */
    double[] weights = null;
    private boolean NORMALIZE_INPUT_VECTORS = false;
    private boolean NORMALIZE_INITIAL_WEIGHTS = false;
    private boolean NORMALIZE_WEIGHTS = false;
    /** When true, feature vectors are passed through {@code VectorTools.crossProduct}. */
    public boolean CROSS_PRODUCT = false;
    /** When true (and {@link #CROSS_PRODUCT} is true), {@code crossProduct} is applied a second time. */
    private boolean DOUBLE_CROSS_PRODUCT = false;
    /** When true, feature vectors are passed through {@code VectorTools.binaryComplement}. */
    public boolean BINARY_COMPLEMENT = false;
    /** Only reported by {@code displayStats}; not otherwise read in this class. */
    public boolean HAND_CRAFTED_VECTORS = false;
    /** Step size applied to every weight change. */
    public double learningRate = 0.01;
    int outputNeuron = 0;
    /** Set by {@link #setMaxMinValue(double)} and reported by {@code displayStats}; not used in any computation here. */
    public double multiplyAndAdd = 10.0;
    private int numPresynaptics = 0;
    private boolean USE_BIAS_UNIT = false;
    private double BIAS_UNIT_ACTIVITY = 1.0;
    /** Selects the residual update (true) or the plain delta rule (false). */
    public boolean RESIDUAL_ALGORITHM = true;
    /** True: apply each weight change immediately. False: accumulate into {@code totalWeightChange}. */
    public boolean INCREMENTAL_UPDATE = true;
    /** Accumulated weight changes for non-incremental mode; null until the weights are initialised. */
    double[] totalWeightChange = null;
    /** Not read anywhere in this class. */
    public boolean SCALE_OUTPUT_VALUE = false;
    private String name = "valuefunction";
    private int numWeights = 0;
    /** Running sum of {@code VectorTools.abssum} over applied incremental weight changes. */
    double totalIncrementalWeightChange = 0.0;

    /**
     * Sets the world used for terminal-state checks.
     *
     * @param w the world; may be null, in which case no state is treated as terminal
     */
    public void setWorld(World w) {
        this.world = w;
    }

    /**
     * Sets the label returned by {@link #getName()}.
     *
     * @param name the new label
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * @return the label of this value function
     */
    public String getName() {
        return this.name;
    }

    /**
     * Not supported by this approximator.
     *
     * @param newValueOfOutOfBoundsStates ignored
     * @throws UnsupportedOperationException always
     */
    public void setValueOfOutOfBoundsStates(double newValueOfOutOfBoundsStates) {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * Not supported by this approximator.
     *
     * @param trueOrFalse ignored
     * @throws UnsupportedOperationException always
     */
    public void setAllowExpansionOfStateBounds(boolean trueOrFalse) {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * Draws one initial weight.
     *
     * @return {@code nextDouble() - 0.5} from the seeded generator
     */
    private double smallRandomWeight() {
        return 1.0 * (this.randomNumberGenerator.nextDouble() - 0.5);
    }

    /**
     * @param vector encoded feature vector
     * @return {@code VectorTools.vdot} of the vector with the current weights
     */
    private double feedforward(double[] vector) {
        return VectorTools.vdot(vector, this.weights);
    }

    /**
     * @param vector encoded feature vector
     * @return {@code VectorTools.vdot} of the vector with the current weights
     */
    private double feedforward(int[] vector) {
        return VectorTools.vdot(vector, this.weights);
    }

    /**
     * Returns its argument unchanged. Not called within this class.
     *
     * @param stateVector encoded feature vector
     * @return the same array instance
     */
    private double[] derivativeOfFunction(double[] stateVector) {
        return stateVector;
    }

    /**
     * Sum of element-wise products over the indices of {@code a}. Not called within this class.
     *
     * @param a first vector; its length drives the loop
     * @param b second vector; must be at least as long as {@code a}
     * @return the dot product
     */
    private static double dot(double[] a, double[] b) {
        double result = 0.0;
        for (int i = 0; i < a.length; ++i) {
            result += a[i] * b[i];
        }
        return result;
    }

    /**
     * Applies the configured feature transformations to a state's binary
     * representation, in the fixed order: cross product (optionally twice),
     * binary complement, normalisation. Shared by training and evaluation so
     * both always see the same encoding.
     *
     * @param state the state to encode
     * @return the encoded feature vector
     */
    private int[] featureVector(HasVectorRepresentation state) {
        int[] encoded = state.binaryRepresentation();
        if (this.CROSS_PRODUCT) {
            encoded = VectorTools.crossProduct(encoded);
            if (this.DOUBLE_CROSS_PRODUCT) {
                encoded = VectorTools.crossProduct(encoded);
            }
        }
        if (this.BINARY_COMPLEMENT) {
            encoded = VectorTools.binaryComplement(encoded);
        }
        if (this.NORMALIZE_INPUT_VECTORS) {
            // NOTE(review): the result is stored back into an int[]; confirm that
            // VectorTools.normalize(int[]) does something useful for integer vectors.
            encoded = VectorTools.normalize(encoded);
        }
        return encoded;
    }

    /**
     * Integer-vector convenience overload: converts to doubles and delegates.
     *
     * @param vector a sample encoded feature vector
     */
    private void initializeNetworkForVectorsOfType(int[] vector) {
        double[] temp = new double[vector.length];
        for (int i = 0; i < vector.length; ++i) {
            temp[i] = vector[i];
        }
        this.initializeNetworkForVectorsOfType(temp);
    }

    /**
     * Allocates the weight vector and the batch accumulator to match the length
     * of the given sample vector, fills the weights with small random values,
     * optionally normalises them, and prints the configuration.
     *
     * @param thisStateObjectVector a sample encoded feature vector; only its length is used
     */
    private void initializeNetworkForVectorsOfType(double[] thisStateObjectVector) {
        this.numWeights = thisStateObjectVector.length;
        this.weights = new double[this.numWeights];
        this.totalWeightChange = new double[this.numWeights];
        System.out.println("Weight vector is has " + this.numWeights + " elements");
        for (int i = 0; i < this.weights.length; ++i) {
            this.weights[i] = this.smallRandomWeight();
        }
        if (this.NORMALIZE_INITIAL_WEIGHTS) {
            System.out.println(this.getClass() + " Normalizing initial weights");
            this.weights = VectorTools.normalize(this.weights);
        }
        this.displayStats();
    }

    /**
     * Performs one training step moving the value of {@code thisStateObject}
     * toward {@code newValue}.
     *
     * @param thisStateObject the state whose value is being trained
     * @param nextStateObject the successor state; used only by the residual update
     * @param newValue the training target for {@code thisStateObject}
     * @param discountFactor multiplied with {@code residualWeighting} to scale the successor's features
     */
    private void setValue(HasVectorRepresentation thisStateObject, HasVectorRepresentation nextStateObject, double newValue, double discountFactor) {
        int[] thisState = this.featureVector(thisStateObject);
        int[] nextState = this.featureVector(nextStateObject);
        if (this.weights == null) {
            this.initializeNetworkForVectorsOfType(thisState);
        }
        // Error signal: target minus current estimate (which honours terminal states).
        double teach = newValue - this.getValue(thisStateObject);
        double[] weightChange;
        if (this.RESIDUAL_ALGORITHM) {
            // Direction is (residualWeighting * discount * next) - this; the -1 factor
            // below flips it so the current state's features are reinforced.
            double[] residual = VectorTools.vsub(VectorTools.vmul(this.residualWeighting * discountFactor, nextState), thisState);
            weightChange = VectorTools.vmul(-1.0 * this.learningRate * teach, residual);
        } else {
            weightChange = VectorTools.vmul(this.learningRate * teach, thisState);
        }
        if (this.INCREMENTAL_UPDATE) {
            this.weights = VectorTools.vadd(this.weights, weightChange);
            this.totalIncrementalWeightChange += VectorTools.abssum(weightChange);
            if (this.NORMALIZE_WEIGHTS) {
                this.weights = VectorTools.normalize(this.weights);
            }
        } else {
            this.totalWeightChange = VectorTools.vadd(this.totalWeightChange, weightChange);
        }
    }

    /**
     * Returns the accumulated magnitude of incremental weight changes and resets the counter.
     *
     * @return the sum accumulated since the previous call (or since construction)
     */
    public double getTotalIncrementalWeightChangeAndReset() {
        double result = this.totalIncrementalWeightChange;
        this.totalIncrementalWeightChange = 0.0;
        return result;
    }

    /**
     * Estimates the value of a state.
     *
     * @param state the state to evaluate; must also implement {@link HasVectorRepresentation}
     * @return the terminal value for terminal states, 0.0 before any training, otherwise the linear estimate
     */
    public double getValue(State state) {
        return this.getValue((HasVectorRepresentation)((Object)state));
    }

    /**
     * Estimates the value of a state.
     *
     * @param state the state to evaluate; cast to {@link State} for the terminal check when a world is set
     * @return the terminal value for terminal states, 0.0 before any training, otherwise the linear estimate
     */
    private double getValue(HasVectorRepresentation state) {
        // A missing world (no-arg constructor) means there are no terminal states to special-case.
        if (this.world != null && this.world.isTerminalState((State)((Object)state))) {
            return this.valueOfTerminalStates;
        }
        if (this.weights == null) {
            return 0.0;
        }
        return this.feedforward(this.featureVector(state));
    }

    /**
     * Trains the value of a single state toward {@code newValue} with no successor:
     * the state is passed as its own successor with a discount factor of 0.0.
     *
     * @param state the state to train; must also implement {@link HasVectorRepresentation}
     * @param newValue the training target
     */
    public void setValue(State state, double newValue) {
        this.setValue((HasVectorRepresentation)((Object)state), (HasVectorRepresentation)((Object)state), newValue, 0.0);
    }

    /**
     * Creates an approximator bound to the given world.
     *
     * @param w world used for terminal-state checks
     */
    public ValueFunctionResidualAlgorithmLinear(World w) {
        System.out.println("ValueFunctionResidualAlgorithmLinear initialized.");
        this.world = w;
    }

    /**
     * Creates an approximator without a world. Until {@link #setWorld(World)} is
     * called, no state is treated as terminal.
     */
    public ValueFunctionResidualAlgorithmLinear() {
        System.out.println("ValueFunctionResidualAlgorithmLinear -- no world passed in! No terminal states?");
    }

    /**
     * @param newValue new multiplier on the successor state's features in the residual update
     */
    public void setResidualWeighting(double newValue) {
        this.residualWeighting = newValue;
    }

    /**
     * Ignored: this approximator has no hidden units. Prints a notice to standard error.
     *
     * @param newValue ignored
     */
    public void setHiddenUnits(int newValue) {
        System.err.println("There are no hidden units in this functional approximator -- set command ignored.");
    }

    /**
     * @param newValue new step size for weight updates
     */
    public void setLearningRate(double newValue) {
        this.learningRate = newValue;
    }

    /**
     * Stores the value in {@link #multiplyAndAdd}; it is reported by the stats
     * printout but not used in any computation in this class.
     *
     * @param newValue the value to store
     */
    public void setMaxMinValue(double newValue) {
        this.multiplyAndAdd = newValue;
    }

    /**
     * @return a fixed placeholder description
     */
    public String toText() {
        return "[ValueFunctionResidualAlgorithm - toText()]";
    }

    /**
     * @param newValue value to report for terminal states
     */
    public void setValueOfTerminalStates(double newValue) {
        this.valueOfTerminalStates = newValue;
    }

    /**
     * Performs one training step for a transition.
     *
     * @param thisState the state being trained; must also implement {@link HasVectorRepresentation}
     * @param nextState the successor state; must also implement {@link HasVectorRepresentation}
     * @param newValue the training target for {@code thisState}
     * @param discountFactor discount applied to the successor's features in the residual update
     */
    public void setValue(State thisState, State nextState, double newValue, double discountFactor) {
        this.setValue((HasVectorRepresentation)((Object)thisState), (HasVectorRepresentation)((Object)nextState), newValue, discountFactor);
    }

    /**
     * In non-incremental mode, applies the accumulated weight changes to the
     * weights and clears the accumulator. Does nothing in incremental mode, or
     * when no training step has happened yet (nothing has been accumulated and
     * the weight arrays do not exist).
     */
    public void storeWeightChangesIfNonIncremental() {
        if (this.INCREMENTAL_UPDATE || this.weights == null) {
            return;
        }
        System.out.println("Storing batch weight changes.");
        this.weights = VectorTools.vadd(this.weights, this.totalWeightChange);
        if (this.NORMALIZE_WEIGHTS) {
            this.weights = VectorTools.normalize(this.weights);
        }
        this.totalWeightChange = new double[this.numWeights];
    }

    /** Intentionally empty; not called within this class. */
    private void zeroOutTotalWeightChangeForEpoch() {
    }

    /** Prints the current configuration flags and parameters to standard output. */
    private void displayStats() {
        System.out.println("Residual Algorithm/Linear function approximator.\n NORMALIZE_INPUT_VECTORS=" + this.NORMALIZE_INPUT_VECTORS + ". Use_bias_unit=" + this.USE_BIAS_UNIT);
        System.out.println(" residualWeighting = " + this.residualWeighting);
        System.out.println(" INCREMENTAL_UPDATE = " + this.INCREMENTAL_UPDATE);
        System.out.println(" multiplyAndAdd = " + this.multiplyAndAdd);
        System.out.println(" BINARY_COMPLEMENT = " + this.BINARY_COMPLEMENT);
        System.out.println(" CROSS_PRODUCT = " + this.CROSS_PRODUCT);
        System.out.println(" RESIDUAL_ALGORITHM = " + this.RESIDUAL_ALGORITHM);
        System.out.println(" HAND_CRAFTED_VECTORS = " + this.HAND_CRAFTED_VECTORS);
    }
}

