/*
 * Decompiled with CFR 0.152.
 */
package net.pakl.rl;

import net.pakl.neuralnet.Perceptron;
import net.pakl.rl.HasVectorRepresentation;
import net.pakl.rl.State;
import net.pakl.rl.ValueFunction;
import net.pakl.rl.World;

/**
 * {@link ValueFunction} whose values are approximated by a {@link Perceptron}.
 *
 * <p>The network is created lazily by the first call to {@link #setValue(State, double)}. It has
 * one input per component of the state's vector representation, {@link #hiddenUnits} further
 * neurons and a single output neuron. Values are stored in the network after a linear mapping of
 * {@code [-multiplyAndAdd, multiplyAndAdd]} onto {@code [0, 1]} (see {@link #forNetwork(double)}
 * and {@link #fromNetwork(double)}).
 *
 * <p>Every {@link State} passed to this class is cast to {@link HasVectorRepresentation}, so states
 * must implement that interface.
 */
public class ValueFunctionPerceptron
implements ValueFunction {
    /**
     * World used to decide whether a state is terminal. {@link #getValue(State)} dereferences it,
     * so it must be set (constructor or {@link #setWorld(World)}) before values are read.
     */
    World world = null;
    /** Value reported for every state the world classifies as terminal. */
    double valueOfTerminalStates = 0.0;
    /**
     * Residual-algorithm weighting. This class only logs it; presumably consumed by
     * residual-gradient code elsewhere — TODO confirm.
     */
    double residualWeighting = 0.4;
    /** Approximating network; {@code null} until the first {@link #setValue(State, double)}. */
    Perceptron net = null;
    /** Number of hidden neurons; fixed once the network exists. */
    public int hiddenUnits = 5;
    /** Learning rate handed to the {@link Perceptron} constructor when the network is created. */
    public double learningRate = 0.01;
    /** Momentum handed to the {@link Perceptron} constructor when the network is created. */
    public double momentum = 0.0;
    /** Index of the single output neuron ({@code inputSize + hiddenUnits}). */
    int outputNeuron = 0;
    /** Half-width of the representable value range {@code [-multiplyAndAdd, multiplyAndAdd]}. */
    public double multiplyAndAdd = 10.0;
    /** Number of rows in the network's weight matrix (includes bias neurons, per the log output). */
    private int numPresynaptics = 0;
    /**
     * Batch weight-change accumulator applied by {@link #storeWeightChangesIfNonIncremental()}.
     * Nothing in this class adds to it; presumably filled by subclasses or package code — confirm.
     */
    double[][] totalWeightChangeForEpoch = null;
    /** Seed handed to the {@link Perceptron} constructor when the network is created. */
    long randomSeed = 12345L;
    /** Not read by this class. */
    boolean RESIDUAL_ALGORITHM = true;
    /**
     * When {@code false}, {@link #storeWeightChangesIfNonIncremental()} applies the accumulated
     * batch changes. Note that {@link #setValue(State, double)} backpropagates immediately
     * regardless of this flag.
     */
    public boolean INCREMENTAL_UPDATE = true;
    /** Not read by this class. */
    public boolean CROSS_PRODUCT = false;
    /** Not read by this class. */
    public boolean USE_REAL_VECTORS = true;
    private String name = "valuefunction";

    /**
     * Creates a value function bound to the given world.
     *
     * @param w world used to detect terminal states
     */
    public ValueFunctionPerceptron(World w) {
        this.world = w;
    }

    /** Creates a value function; call {@link #setWorld(World)} before reading values. */
    public ValueFunctionPerceptron() {
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getName() {
        return this.name;
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    public void setValueOfOutOfBoundsStates(double newValueOfOutOfBoundsStates) {
        throw new UnsupportedOperationException("Not implemented.");
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    public void setAllowExpansionOfStateBounds(boolean trueOrFalse) {
        throw new UnsupportedOperationException("Not implemented.");
    }

    public void setWorld(World w) {
        this.world = w;
    }

    /**
     * In batch mode ({@code INCREMENTAL_UPDATE == false}), adds the accumulated weight changes to
     * the network's weights and resets the accumulator. Does nothing in incremental mode, or when
     * no network has been created yet (there is nothing to update).
     */
    public void storeWeightChangesIfNonIncremental() {
        if (this.INCREMENTAL_UPDATE) {
            return;
        }
        if (this.net == null) {
            // The network (and the accumulator) only exist after the first setValue call.
            return;
        }
        System.out.println("Storing batch weight changes.");
        double[][] currentWeights = this.net.getWeights();
        this.net.setWeights(this.vadd(currentWeights, this.totalWeightChangeForEpoch));
        this.zeroOutTotalWeightChangeForEpoch();
    }

    /** Replaces the batch accumulator with a fresh all-zero square matrix. */
    private void zeroOutTotalWeightChangeForEpoch() {
        this.totalWeightChangeForEpoch = new double[this.numPresynaptics][this.numPresynaptics];
    }

    /** @return the approximating network, or {@code null} if it has not been created yet */
    public Perceptron getNetwork() {
        return this.net;
    }

    /**
     * Builds the network sized for states like {@code thisStateObject}. The layout is:
     * <ul>
     *   <li>one input per vector component;</li>
     *   <li>{@link #hiddenUnits} hidden neurons;</li>
     *   <li>one output neuron.</li>
     * </ul>
     * The method then randomizes the weights and resets the batch accumulator.
     */
    private void initializeNetworkForStatesOfType(HasVectorRepresentation thisStateObject) {
        System.out.println("Initialize network for ValueFunctionPerceptron");
        int inputSize = thisStateObject.doubleRepresentation().length;
        int totalNeurons = inputSize + this.hiddenUnits + 1;
        int[] layerBreaks = new int[]{inputSize, inputSize + this.hiddenUnits, totalNeurons};
        System.out.println("Input size v2: " + inputSize);
        System.out.println("Hidden units:" + this.hiddenUnits);
        this.net = new Perceptron(this.randomSeed, totalNeurons, this.learningRate, this.momentum, layerBreaks);
        this.net.reinitializeRandomWeightsWith(1.0);
        System.out.println("We specified " + totalNeurons + " neurons and learn rate is " + this.learningRate);
        System.out.println("momentum is = " + this.momentum);
        this.numPresynaptics = this.net.getWeights().length;
        System.out.println("Perceptron really has (because of BIAS neurons) presynaptics with weights " + this.numPresynaptics);
        this.outputNeuron = inputSize + this.hiddenUnits;
        System.out.println("outputNeuron is neuron #" + this.outputNeuron);
        this.zeroOutTotalWeightChangeForEpoch();
        System.out.println("residualWeighting = " + this.residualWeighting);
        System.out.println("multiplyAndAdd = " + this.multiplyAndAdd);
    }

    /** Views a state through its vector interface; states used here must implement both types. */
    private static HasVectorRepresentation asVector(State state) {
        return (HasVectorRepresentation) ((Object) state);
    }

    /**
     * Returns the estimated value of a state.
     *
     * @param state state to evaluate; must implement {@link HasVectorRepresentation}
     * @return {@code valueOfTerminalStates} for terminal states, {@code 0.0} before the network
     *     exists, otherwise the network output mapped back to the value range
     */
    public double getValue(State state) {
        return this.getValue(asVector(state));
    }

    private double getValue(HasVectorRepresentation state) {
        if (this.world.isTerminalState((State) ((Object) state))) {
            return this.valueOfTerminalStates;
        }
        if (this.net == null) {
            return 0.0;
        }
        this.net.feedforward(state.doubleRepresentation());
        return this.fromNetwork(this.net.getActivity(this.outputNeuron));
    }

    /** @return the sum of all elements of {@code v} */
    protected double sum(double[] v) {
        double result = 0.0;
        for (double x : v) {
            result += x;
        }
        return result;
    }

    /** @return the sum of all elements of {@code v}; rows may differ in length */
    protected double sum(double[][] v) {
        double result = 0.0;
        for (double[] row : v) {
            for (double x : row) {
                result += x;
            }
        }
        return result;
    }

    /** @return element-wise {@code v1 - v2}, sized like {@code v1} */
    protected double[] vsub(double[] v1, double[] v2) {
        double[] result = new double[v1.length];
        for (int i = 0; i < v1.length; ++i) {
            result[i] = v1[i] - v2[i];
        }
        return result;
    }

    /** @return element-wise {@code v1 - v2}, shaped row-by-row like {@code v1} */
    protected double[][] vsub(double[][] v1, double[][] v2) {
        double[][] result = new double[v1.length][];
        for (int i = 0; i < v1.length; ++i) {
            result[i] = new double[v1[i].length];
            for (int j = 0; j < v1[i].length; ++j) {
                result[i][j] = v1[i][j] - v2[i][j];
            }
        }
        return result;
    }

    /** @return element-wise {@code v1 + v2}, sized like {@code v1} */
    protected double[] vadd(double[] v1, double[] v2) {
        double[] result = new double[v1.length];
        for (int i = 0; i < v1.length; ++i) {
            result[i] = v1[i] + v2[i];
        }
        return result;
    }

    /** @return a {@code double} copy of {@code x} */
    protected double[] intsToDoubles(int[] x) {
        double[] result = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            result[i] = x[i];
        }
        return result;
    }

    /** @return element-wise {@code v1 + v2}, shaped row-by-row like {@code v1} */
    protected double[][] vadd(double[][] v1, double[][] v2) {
        double[][] result = new double[v1.length][];
        for (int i = 0; i < v1.length; ++i) {
            result[i] = new double[v1[i].length];
            for (int j = 0; j < v1[i].length; ++j) {
                result[i][j] = v1[i][j] + v2[i][j];
            }
        }
        return result;
    }

    /** @return {@code scalar * vector} as a new array */
    protected double[] vmul(double scalar, double[] vector) {
        double[] result = new double[vector.length];
        for (int i = 0; i < vector.length; ++i) {
            result[i] = scalar * vector[i];
        }
        return result;
    }

    /** @return {@code scalar * vector} as a new array, shaped row-by-row like {@code vector} */
    protected double[][] vmul(double scalar, double[][] vector) {
        double[][] result = new double[vector.length][];
        for (int i = 0; i < vector.length; ++i) {
            result[i] = new double[vector[i].length];
            for (int j = 0; j < vector[i].length; ++j) {
                result[i][j] = scalar * vector[i][j];
            }
        }
        return result;
    }

    /**
     * @return the dot product of {@code u} and {@code v}
     * @throws IllegalArgumentException if the vectors differ in length
     */
    protected double dotProduct(double[] u, double[] v) {
        if (u.length != v.length) {
            throw new IllegalArgumentException("Vectors must be of same length for dot product.");
        }
        double result = 0.0;
        for (int i = 0; i < u.length; ++i) {
            result += u[i] * v[i];
        }
        return result;
    }

    /**
     * Trains the network one step toward {@code newValue} for {@code state}. The first call creates
     * the network, sized from this state's vector representation. The update is applied
     * immediately (a feedforward followed by backpropagation), regardless of
     * {@code INCREMENTAL_UPDATE}.
     *
     * @param state state to update; must implement {@link HasVectorRepresentation}
     * @param newValue target value, expected within {@code [-multiplyAndAdd, multiplyAndAdd]}
     */
    public void setValue(State state, double newValue) {
        HasVectorRepresentation vectorState = asVector(state);
        if (this.net == null) {
            this.initializeNetworkForStatesOfType(vectorState);
        }
        this.net.feedforward(vectorState.doubleRepresentation());
        this.net.backpropogate(new double[]{this.forNetwork(newValue)});
    }

    public void setResidualWeighting(double newValue) {
        this.residualWeighting = newValue;
    }

    /**
     * Sets the number of hidden neurons.
     *
     * @throws IllegalStateException if the network has already been created
     */
    public void setHiddenUnits(int newValue) {
        if (this.net != null) {
            throw new IllegalStateException("Sorry -- you cannot set this AFTER net already initialized.");
        }
        this.hiddenUnits = newValue;
    }

    /** Sets the learning rate; only takes effect if called before the network is created. */
    public void setLearningRate(double newValue) {
        this.learningRate = newValue;
    }

    /** Sets the momentum; only takes effect if called before the network is created. */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /**
     * Sets the half-width of the representable value range.
     *
     * @throws IllegalStateException if the network has already been created
     */
    public void setMaxMinValue(double newValue) {
        if (this.net != null) {
            throw new IllegalStateException("Sorry -- you cannot set this AFTER net already initialized.");
        }
        this.multiplyAndAdd = newValue;
    }

    public String toText() {
        return "[ValueFunctionResidualAlgorithm - toText()]";
    }

    public void setValueOfTerminalStates(double newValue) {
        this.valueOfTerminalStates = newValue;
    }

    /** Maps a value in {@code [-multiplyAndAdd, multiplyAndAdd]} linearly onto {@code [0, 1]}. */
    public double forNetwork(double value) {
        return (this.multiplyAndAdd + value) / (2.0 * this.multiplyAndAdd);
    }

    /** Inverse of {@link #forNetwork(double)}: maps {@code [0, 1]} back to the value range. */
    public double fromNetwork(double value) {
        return 2.0 * this.multiplyAndAdd * value - this.multiplyAndAdd;
    }

    /** Sets the network's random seed; only takes effect if called before the network is created. */
    public void setRandomSeed(long x) {
        this.randomSeed = x;
    }
}

