/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.rl;

import net.pakl.neuralnet.PerceptronPatrykThesis;
import net.pakl.rl.HasVectorRepresentation;
import net.pakl.rl.State;
import net.pakl.rl.ValueFunction;
import net.pakl.rl.ValueFunctionResidualAlgorithmPerceptron;
import net.pakl.rl.World;

/**
 * Value function backed by a {@link PerceptronPatrykThesis} network with context units.
 *
 * <p>Each evaluation or update presents the same state vector to the network
 * {@link #timestepsToReflect} times, copying the hidden excitation into the context
 * units after every pass. The weight update in {@link #setValue} combines the output
 * gradients taken at the current state and at the successor state.
 */
public class ValueFunctionResidualAlgorithmPatrykNetwork
extends ValueFunctionResidualAlgorithmPerceptron
implements ValueFunction {
    /** Number of context units passed to the network; set equal to {@code hiddenUnits} during initialization. */
    protected int numContextUnits;
    /** Number of feedforward / copy-to-context passes performed per state in {@link #getValue} and {@link #setValue}. */
    public int timestepsToReflect = 4;
    /** When true, {@link #setValue} clears the context units after updating on a state the world reports as terminal. */
    public boolean clearBeforeTrajectories = true;
    /** When true, the context units are cleared before the passes for each state. */
    public boolean clearBeforeEachState = true;

    /** Creates an instance without calling the {@code World}-taking superclass constructor. */
    public ValueFunctionResidualAlgorithmPatrykNetwork() {
    }

    /**
     * Creates an instance for the given world.
     *
     * @param w world forwarded to the superclass constructor
     */
    public ValueFunctionResidualAlgorithmPatrykNetwork(World w) {
        super(w);
    }

    /**
     * Builds the network, sized from the vector representation of a sample state, and
     * resets the per-epoch weight-change accumulator. Configuration is echoed to stdout.
     *
     * @param thisStateObject sample state whose vector length determines the input size
     */
    protected void initializeNetworkForStatesOfType(HasVectorRepresentation thisStateObject) {
        System.out.println("Initialize network for ValueFunctionResidualAlgorithmPatrykNetwork");
        int inputSize = thisStateObject.doubleRepresentation().length;
        // One context unit per hidden unit.
        this.numContextUnits = this.hiddenUnits;
        System.out.println("Input size: " + inputSize);
        System.out.println("Context units: " + this.numContextUnits);
        System.out.println("Hidden units:" + this.hiddenUnits);
        System.out.println("Parameters affecting Recurrent Computation:");
        System.out.println("  timestepsToReflect = " + this.timestepsToReflect);
        System.out.println("  clearBeforeTrajectories = " + this.clearBeforeTrajectories);
        System.out.println("  clearBeforeEachState = " + this.clearBeforeEachState);
        this.net = new PerceptronPatrykThesis(this.randomSeed, inputSize, this.numContextUnits, this.hiddenUnits, this.learningRate, 0.0);
        // NOTE(review): this count omits numContextUnits, whereas outputNeuron below includes it;
        // the message may under-report the neuron count -- confirm against PerceptronPatrykThesis.
        System.out.println("We specified " + (inputSize + this.hiddenUnits + 1) + " neurons and learn rate is " + this.learningRate);
        this.numPresynaptics = this.net.getWeights().length;
        System.out.println("Perceptron really has (because of BIAS neurons) presynaptics with weights " + this.numPresynaptics);
        this.outputNeuron = inputSize + this.hiddenUnits + this.numContextUnits;
        System.out.println("outputNeuron is neuron #" + this.outputNeuron);
        this.totalWeightChangeForEpoch = new double[this.numPresynaptics][this.numPresynaptics];
        System.out.println("residualWeighting = " + this.residualWeighting);
        System.out.println("[NOT USED] multiplyAndAdd = " + this.multiplyAndAdd);
    }

    /**
     * Performs one training update for the transition {@code thisStateObject -> nextStateObject}.
     *
     * <p>The network is created lazily on first use. For each of {@link #timestepsToReflect}
     * passes the output gradient is taken at the current state and at the next state, and
     * the weight change is
     * {@code multiplyBy * (scale * gradNext - gradThis)}, where
     * {@code multiplyBy = -learningRate * (newValue - getValue(thisStateObject))} and
     * {@code scale} is {@code residualWeighting * discountFactor} when
     * {@code RESIDUAL_ALGORITHM} is set, otherwise {@code discountFactor}. The change is
     * either applied immediately ({@code INCREMENTAL_UPDATE}) or accumulated into
     * {@code totalWeightChangeForEpoch}.
     *
     * @param thisStateObject current state
     * @param nextStateObject successor state
     * @param newValue        target value for the current state
     * @param discountFactor  discount applied to the successor-state gradient
     */
    protected void setValue(HasVectorRepresentation thisStateObject, HasVectorRepresentation nextStateObject, double newValue, double discountFactor) {
        if (this.net == null) {
            this.initializeNetworkForStatesOfType(thisStateObject);
        }
        double[] thisState = thisStateObject.doubleRepresentation();
        double[] nextState = nextStateObject.doubleRepresentation();
        // getValue() itself runs the network (and touches the context units), so it must be
        // evaluated before the context reset below.
        double multiplyBy = -1.0 * this.learningRate * (newValue - this.getValue(thisStateObject));
        // The two algorithm variants differ only in this scalar; it does not change across passes.
        double nextGradientScale = this.RESIDUAL_ALGORITHM ? this.residualWeighting * discountFactor : discountFactor;
        if (this.clearBeforeEachState) {
            this.recurrentNet().clearContextUnits();
        }
        for (int pass = 0; pass < this.timestepsToReflect; ++pass) {
            double[][] currentWeights = this.net.getWeights();
            this.net.feedforward(thisState);
            double[][] gradientAtThisState = this.net.getOutputDerivativeAgainstWeights();
            this.net.feedforward(nextState);
            double[][] gradientAtNextState = this.net.getOutputDerivativeAgainstWeights();
            double[][] weightChange = this.vmul(multiplyBy, this.vsub(this.vmul(nextGradientScale, gradientAtNextState), gradientAtThisState));
            if (this.INCREMENTAL_UPDATE) {
                this.net.setWeights(this.vadd(currentWeights, weightChange));
            } else {
                this.totalWeightChangeForEpoch = this.vadd(this.totalWeightChangeForEpoch, weightChange);
            }
            this.recurrentNet().copyHiddenExcitationToContextUnits();
        }
        // Cast goes through Object because the static relationship between
        // HasVectorRepresentation and State is not visible here.
        if (this.clearBeforeTrajectories && this.world.isTerminalState((State)((Object)thisStateObject))) {
            this.recurrentNet().clearContextUnits();
        }
    }

    /**
     * Returns the network's estimate for a state.
     *
     * <p>Terminal states (as reported by the world) yield {@code valueOfTerminalStates};
     * if the network has not been created yet the result is {@code 0.0}. Otherwise the
     * state is presented {@link #timestepsToReflect} times, copying hidden excitation to
     * the context units after each pass, and the net input of the output neuron is returned.
     *
     * @param state state to evaluate
     * @return the value estimate described above
     */
    protected double getValue(HasVectorRepresentation state) {
        if (this.world.isTerminalState((State)((Object)state))) {
            return this.valueOfTerminalStates;
        }
        if (this.net == null) {
            return 0.0;
        }
        double[] stateVector = state.doubleRepresentation();
        if (this.clearBeforeEachState) {
            this.recurrentNet().clearContextUnits();
        }
        for (int pass = 0; pass < this.timestepsToReflect; ++pass) {
            this.net.feedforward(stateVector);
            this.recurrentNet().copyHiddenExcitationToContextUnits();
        }
        return this.net.getNetInput(this.outputNeuron);
    }

    /** Returns {@code net} downcast to the type that exposes the context-unit operations. */
    private PerceptronPatrykThesis recurrentNet() {
        return (PerceptronPatrykThesis)this.net;
    }
}

