/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.rl;

import net.pakl.neuralnet.PerceptronBairdThesis;
import net.pakl.rl.HasVectorRepresentation;
import net.pakl.rl.State;
import net.pakl.rl.ValueFunction;
import net.pakl.rl.ValueFunctionResidualAlgorithmPerceptron;
import net.pakl.rl.World;

/**
 * Value function approximated by a {@link PerceptronBairdThesis} network and trained with a
 * residual-style weight update.
 *
 * <p>The network is created lazily on the first call to {@link #setValue}, once the length of a
 * state's vector representation is known. Until then {@link #getValue} returns {@code 0.0} for
 * non-terminal states.
 */
public class ValueFunctionResidualAlgorithmBairdPerceptron
extends ValueFunctionResidualAlgorithmPerceptron
implements ValueFunction {
    /** Creates an instance via the superclass's no-argument constructor. */
    public ValueFunctionResidualAlgorithmBairdPerceptron() {
    }

    /**
     * Creates an instance bound to the given world.
     *
     * @param w the world, forwarded to the superclass constructor (presumably stored as
     *     {@code this.world}, which {@link #getValue} reads for terminal-state checks — TODO confirm)
     */
    public ValueFunctionResidualAlgorithmBairdPerceptron(World w) {
        super(w);
    }

    /**
     * Builds the network for states shaped like {@code thisStateObject} and resets the
     * per-epoch weight-change accumulator.
     *
     * <p>The network input size is the length of the state's {@code doubleRepresentation()}.
     * The output neuron index is {@code inputSize + hiddenUnits}. Configuration details are logged
     * to standard output.
     *
     * @param thisStateObject a representative state whose vector length fixes the input size
     */
    protected void initializeNetworkForStatesOfType(HasVectorRepresentation thisStateObject) {
        System.out.println("Initialize network for ValueFunctionResidualAlgorithmBairdPerceptron");
        int inputSize = thisStateObject.doubleRepresentation().length;
        System.out.println("Input size v2: " + inputSize);
        System.out.println("Hidden units:" + this.hiddenUnits);
        // The meaning of the trailing 0.0 argument is defined by PerceptronBairdThesis (not visible here).
        this.net = new PerceptronBairdThesis(this.randomSeed, inputSize, this.hiddenUnits, this.learningRate, 0.0);
        System.out.println("We specified " + (inputSize + this.hiddenUnits + 1) + " neurons and learn rate is " + this.learningRate);
        // Per the log line below, the weight matrix's first dimension includes bias neurons,
        // so it can exceed the neuron count reported above.
        double[][] initialWeights = this.net.getWeights();
        this.numPresynaptics = initialWeights.length;
        System.out.println("Perceptron really has (because of BIAS neurons) presynaptics with weights " + this.numPresynaptics);
        this.outputNeuron = inputSize + this.hiddenUnits;
        System.out.println("outputNeuron is neuron #" + this.outputNeuron);
        // NOTE(review): assumes getWeights() is square (numPresynaptics x numPresynaptics) so
        // that vadd with per-step weight changes lines up — confirm against PerceptronBairdThesis.
        this.totalWeightChangeForEpoch = new double[this.numPresynaptics][this.numPresynaptics];
        System.out.println("residualWeighting = " + this.residualWeighting);
        System.out.println("[NOT USED] multiplyAndAdd = " + this.multiplyAndAdd);
    }

    /**
     * Moves the network's estimate for {@code thisStateObject} toward {@code newValue}.
     *
     * <p>With learning rate {@code a}, discount {@code g}, and scale factor {@code phi}, the weight
     * change is:
     *
     * <pre>
     *   dW = -a * (newValue - V(s)) * (phi * g * dV(s')/dW - dV(s)/dW)
     * </pre>
     *
     * <p>{@code phi} is {@code residualWeighting} when {@code RESIDUAL_ALGORITHM} is set, otherwise
     * {@code 1}. The change is applied immediately when {@code INCREMENTAL_UPDATE} is set. Otherwise
     * it is added to {@code totalWeightChangeForEpoch}, presumably applied elsewhere at epoch end
     * (not visible here).
     *
     * @param thisStateObject the state whose value is being updated ({@code s})
     * @param nextStateObject the successor state ({@code s'})
     * @param newValue the target value for {@code s}
     * @param discountFactor the discount {@code g} applied to the successor-state derivative
     */
    protected void setValue(HasVectorRepresentation thisStateObject, HasVectorRepresentation nextStateObject, double newValue, double discountFactor) {
        if (this.net == null) {
            this.initializeNetworkForStatesOfType(thisStateObject);
        }
        double[] thisState = thisStateObject.doubleRepresentation();
        double[] nextState = nextStateObject.doubleRepresentation();
        double multiplyBy = -1.0 * this.learningRate * (newValue - this.getValue(thisStateObject));
        double[][] currentWeights = this.net.getWeights();

        // Output-vs-weight derivatives at s and s'; each requires a forward pass first.
        this.net.feedforward(thisState);
        double[][] partialDerivVx = this.net.getOutputDerivativeAgainstWeights();
        // NOTE(review): if nextStateObject is terminal, getValue() treats its value as the constant
        // valueOfTerminalStates, yet the network derivative at s' is still used here. For a
        // constant target that term would arguably be zero — left unchanged pending confirmation.
        this.net.feedforward(nextState);
        double[][] partialDerivVx2 = this.net.getOutputDerivativeAgainstWeights();

        // The two modes differ only in how strongly the successor-state derivative is weighted.
        double phi = this.RESIDUAL_ALGORITHM ? this.residualWeighting : 1.0;
        double[][] weightChange = this.vmul(multiplyBy, this.vsub(this.vmul(phi * discountFactor, partialDerivVx2), partialDerivVx));
        if (this.INCREMENTAL_UPDATE) {
            this.net.setWeights(this.vadd(currentWeights, weightChange));
        } else {
            this.totalWeightChangeForEpoch = this.vadd(this.totalWeightChangeForEpoch, weightChange);
        }
    }

    /**
     * Returns the current value estimate for {@code state}.
     *
     * <p>Terminal states (per {@code world.isTerminalState}) yield {@code valueOfTerminalStates}.
     * Before the network exists the result is {@code 0.0}. Otherwise the state is fed forward
     * through the network, which changes the network's internal activations. The result is
     * {@code getNetInput} of the output neuron (presumably its pre-activation net input — TODO
     * confirm).
     *
     * @param state the state to evaluate; it is also cast to {@link State} for the terminal check
     * @return the estimated value of {@code state}
     */
    protected double getValue(HasVectorRepresentation state) {
        // Cast via Object, as in the decompiled original, to keep the cast legal at compile time
        // regardless of how State and HasVectorRepresentation are related.
        if (this.world.isTerminalState((State)((Object)state))) {
            return this.valueOfTerminalStates;
        }
        if (this.net == null) {
            return 0.0;
        }
        this.net.feedforward(state.doubleRepresentation());
        return this.net.getNetInput(this.outputNeuron);
    }
}

