package org.encog.mathutil.matrices.hessian;

import java.lang.reflect.Array;
import java.util.Arrays;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.flat.FlatNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;

/* loaded from: classes.dex */
/**
 * Worker that, for one selected output neuron, feeds every training record in the inclusive
 * index range {@code [low, high]} through a {@link FlatNetwork}. It back-propagates that
 * neuron's delta and accumulates, over all those records:
 * <ul>
 *   <li>the sum of squared errors {@code e * e}, where {@code e = ideal - actual}
 *       ({@link #getError()});</li>
 *   <li>the sum of {@code e * d}, where {@code d} is the per-record weight-derivative vector
 *       ({@link #getGradients()});</li>
 *   <li>the plain sum of {@code d} ({@link #getDerivative()});</li>
 *   <li>the sum of outer products {@code d * d^T} ({@link #getHessian()}).</li>
 * </ul>
 * All accumulators are reset at the start of each {@link #run()}. The getters return the
 * internal arrays, not copies.
 *
 * <p>NOTE(review): the weight, layer-output and layer-sum arrays are fetched from the network
 * once, by reference, and read after {@code flat.compute(...)}. This presumably relies on
 * {@code compute} updating those arrays in place. If so, several workers sharing one network
 * concurrently would interfere — confirm that callers give each worker its own network.
 */
public class ChainRuleWorker implements EngineTask {
    /** Scratch buffer receiving the network output for the current record. */
    private final double[] actual;
    /** Sum of squared errors of the selected output neuron over the processed records. */
    private double error;
    private final FlatNetwork flat;
    /** Sum over records of {@code e * derivative}. */
    private final double[] gradients;
    /** Sum over records of the outer product {@code derivative * derivative^T}. */
    private final double[][] hessian;
    /** Last training-record index to process (inclusive). */
    private final int high;
    private final int[] layerCounts;
    /** Per-neuron deltas, indexed like {@link #layerOutput}. */
    private final double[] layerDelta;
    private final int[] layerFeedCounts;
    private final int[] layerIndex;
    private final double[] layerOutput;
    private final double[] layerSums;
    /** First training-record index to process (inclusive). */
    private final int low;
    /** Index of the output neuron whose error is differentiated. */
    private int outputNeuron;
    /** Reusable holder for the record currently being processed. */
    private final MLDataPair pair;
    /** Sum over records of the raw derivative vector. */
    private final double[] totDeriv;
    private final MLDataSet training;
    private final int weightCount;
    private final int[] weightIndex;
    private final double[] weights;

    /**
     * Creates a worker over a slice of the training set.
     *
     * @param theNetwork  network to evaluate; its internal arrays are captured by reference
     * @param theTraining training data
     * @param theLow      first record index to process (inclusive)
     * @param theHigh     last record index to process (inclusive)
     */
    public ChainRuleWorker(FlatNetwork theNetwork, MLDataSet theTraining, int theLow, int theHigh) {
        this.weightCount = theNetwork.getWeights().length;
        this.hessian = new double[this.weightCount][this.weightCount];
        this.training = theTraining;
        this.flat = theNetwork;
        this.layerDelta = new double[theNetwork.getLayerOutput().length];
        this.actual = new double[theNetwork.getOutputCount()];
        this.totDeriv = new double[this.weightCount];
        this.gradients = new double[this.weightCount];
        this.weights = theNetwork.getWeights();
        this.layerIndex = theNetwork.getLayerIndex();
        this.layerCounts = theNetwork.getLayerCounts();
        this.weightIndex = theNetwork.getWeightIndex();
        this.layerOutput = theNetwork.getLayerOutput();
        this.layerSums = theNetwork.getLayerSums();
        this.layerFeedCounts = theNetwork.getLayerFeedCounts();
        this.low = theLow;
        this.high = theHigh;
        this.pair = BasicMLDataPair.createPair(theNetwork.getInputCount(), theNetwork.getOutputCount());
    }

    /**
     * Processes a single record: computes the network output, seeds the delta of the chosen
     * output neuron, back-propagates it, and adds this record's contributions to the
     * accumulators. Only the upper triangle of {@link #hessian} is updated here;
     * {@link #run()} mirrors it into the lower triangle afterwards.
     *
     * @param neuron     output neuron whose error is differentiated
     * @param derivative zeroed buffer of length {@code weightCount}; receives this record's
     *                   weight derivatives
     * @param input      record input
     * @param ideal      record ideal (target) output
     */
    private void process(int neuron, double[] derivative, double[] input, double[] ideal) {
        this.flat.compute(input, this.actual);

        final double e = ideal[neuron] - this.actual[neuron];
        this.error += e * e;

        // Output neurons occupy indices [0, outputCount) and use activation function 0.
        // Only the selected neuron gets a non-zero delta.
        Arrays.fill(this.layerDelta, 0, this.actual.length, 0.0d);
        this.layerDelta[neuron] = this.flat.getActivationFunctions()[0]
                .derivativeFunction(this.layerSums[neuron], this.layerOutput[neuron]);

        for (int level = this.flat.getBeginTraining(); level < this.flat.getEndTraining(); level++) {
            processLevel(level, derivative);
        }

        for (int w = 0; w < this.weightCount; w++) {
            this.gradients[w] += e * derivative[w];
            this.totDeriv[w] += derivative[w];
        }

        accumulateUpperHessian(derivative);
    }

    /**
     * Adds {@code derivative[i] * derivative[j]} to {@code hessian[i][j]} for {@code j >= i}.
     * The outer product is symmetric, so the lower triangle is filled once per run instead of
     * on every record.
     */
    private void accumulateUpperHessian(double[] derivative) {
        for (int i = 0; i < this.weightCount; i++) {
            final double di = derivative[i];
            final double[] row = this.hessian[i];
            for (int j = i; j < this.weightCount; j++) {
                row[j] += di * derivative[j];
            }
        }
    }

    /** Copies the accumulated upper triangle of {@link #hessian} into its lower triangle. */
    private void mirrorUpperTriangle() {
        for (int row = 1; row < this.weightCount; row++) {
            final double[] target = this.hessian[row];
            for (int col = 0; col < row; col++) {
                target[col] = this.hessian[col][row];
            }
        }
    }

    /**
     * Back-propagates deltas from the neurons of {@code level} to the neurons of
     * {@code level + 1}. Level 0 holds the output neurons, seeded in {@link #process}. For
     * each weight between the two levels, adds {@code delta * output} to its derivative.
     * Weights of a level start at {@code weightIndex[level]}; the weight linking neuron
     * {@code x} of {@code level} to neuron {@code y} of {@code level + 1} sits at offset
     * {@code x * layerCounts[level + 1] + y}.
     *
     * @param level      level whose deltas are already known
     * @param derivative per-record derivative buffer to accumulate into
     */
    private void processLevel(int level, double[] derivative) {
        final int next = level + 1;
        final int sourceStart = this.layerIndex[next];
        final int targetStart = this.layerIndex[level];
        final int sourceCount = this.layerCounts[next];
        // NOTE(review): presumably layerFeedCounts excludes bias neurons — confirm in FlatNetwork.
        final int targetCount = this.layerFeedCounts[level];
        final int weightStart = this.weightIndex[level];
        final ActivationFunction activation = this.flat.getActivationFunctions()[next];

        for (int y = 0; y < sourceCount; y++) {
            final int yi = sourceStart + y;
            final double output = this.layerOutput[yi];
            double sum = 0.0d;
            int wi = weightStart + y;
            for (int x = 0; x < targetCount; x++) {
                final double delta = this.layerDelta[targetStart + x];
                derivative[wi] += delta * output;
                sum += this.weights[wi] * delta;
                wi += sourceCount;
            }
            this.layerDelta[yi] = activation.derivativeFunction(this.layerSums[yi], this.layerOutput[yi]) * sum;
        }
    }

    /** @return the internal array holding the summed per-record derivative vectors */
    public double[] getDerivative() {
        return this.totDeriv;
    }

    /** @return the sum of squared errors of the selected output neuron from the last run */
    public double getError() {
        return this.error;
    }

    /** @return the internal array holding the sum of {@code e * derivative} */
    public double[] getGradients() {
        return this.gradients;
    }

    /** @return the internal, symmetric matrix holding the summed outer products after a run */
    public double[][] getHessian() {
        return this.hessian;
    }

    /** @return the network this worker evaluates */
    public FlatNetwork getNetwork() {
        return this.flat;
    }

    /** @return index of the output neuron whose error is differentiated */
    public int getOutputNeuron() {
        return this.outputNeuron;
    }

    /**
     * Resets all accumulators, then processes every record in {@code [low, high]} (inclusive).
     * Finally it mirrors the Hessian's upper triangle so the returned matrix is full.
     */
    @Override
    public void run() {
        this.error = 0.0d;
        EngineArray.fill(this.hessian, 0);
        Arrays.fill(this.totDeriv, 0.0d);
        Arrays.fill(this.gradients, 0.0d);

        final double[] derivative = new double[this.weightCount];
        for (int record = this.low; record <= this.high; record++) {
            this.training.getRecord(record, this.pair);
            Arrays.fill(derivative, 0.0d);
            process(this.outputNeuron, derivative, this.pair.getInputArray(), this.pair.getIdealArray());
        }
        mirrorUpperTriangle();
    }

    /** @param i index of the output neuron whose error is differentiated on the next run */
    public void setOutputNeuron(int i) {
        this.outputNeuron = i;
    }
}
