package org.opengion.penguin.math.statistics;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:WEB-INF/lib/penguin8.3.0.3.jar:org/opengion/penguin/math/statistics/HybsLogisticRegression.class */
/**
 * Multi-class logistic regression (softmax regression) fitted by stochastic
 * gradient descent.
 *
 * <p>The model holds one weight row and one bias per output class. Training is
 * performed entirely inside the constructor; afterwards {@link #predict(double[])}
 * returns the class probabilities for a single input vector.
 *
 * <p>Training order is randomized with {@link Collections#shuffle(List)}, so the
 * fitted parameters are not bit-for-bit reproducible between runs.
 */
public class HybsLogisticRegression {
    /** Factor the learning rate is multiplied by after every epoch. */
    private static final double LEARNING_RATE_DECAY = 0.95d;

    /** Number of training samples. */
    private final int n_N;
    /** Number of input features per sample. */
    private final int n_in;
    /** Number of output classes. */
    private final int n_out;
    /** Weight matrix, indexed as [class][feature]. */
    private final double[][] vW;
    /** Bias vector, one entry per class. */
    private final double[] vb;

    /**
     * Builds the model and trains it on the supplied data.
     *
     * <p>Each epoch shuffles the sample order, applies one gradient step per
     * selected sample, and then multiplies the learning rate by 0.95.
     *
     * @param dArr training inputs, one row per sample; the feature count is taken
     *             from the first row
     * @param iArr training targets, one row per sample (row {@code k} belongs to
     *             {@code dArr[k]}); the class count is taken from the first row
     * @param d    initial learning rate
     * @param i    number of epochs; values {@code <= 0} leave the parameters at zero
     * @param d2   fraction of the samples visited per epoch; the per-epoch sample
     *             count is {@code ceil(n * d2)} clamped to the range {@code [0, n]}
     * @throws IllegalArgumentException if either array is empty or there are fewer
     *                                  target rows than input rows
     */
    public HybsLogisticRegression(double[][] dArr, int[][] iArr, double d, int i, double d2) {
        if (dArr.length == 0 || iArr.length == 0) {
            throw new IllegalArgumentException("training data must not be empty");
        }
        if (iArr.length < dArr.length) {
            throw new IllegalArgumentException(
                    "fewer target rows (" + iArr.length + ") than input rows (" + dArr.length + ")");
        }
        this.n_N = dArr.length;
        this.n_in = dArr[0].length;
        this.n_out = iArr[0].length;
        this.vW = new double[this.n_out][this.n_in];
        this.vb = new double[this.n_out];

        Integer[] indices = new Integer[this.n_N];
        for (int k = 0; k < this.n_N; k++) {
            indices[k] = Integer.valueOf(k);
        }
        // Fixed-size view over the array; shuffle() only swaps elements, so this is sufficient.
        List<Integer> order = Arrays.asList(indices);

        // ceil() reproduces the iteration count of the former "k < n * ratio" loop; the clamp
        // keeps a ratio above 1.0 from indexing past the end of the shuffled index list.
        int samplesPerEpoch = Math.min(this.n_N, Math.max(0, (int) Math.ceil(this.n_N * d2)));

        double learningRate = d;
        for (int epoch = 0; epoch < i; epoch++) {
            Collections.shuffle(order);
            for (int k = 0; k < samplesPerEpoch; k++) {
                int sample = order.get(k).intValue();
                train(dArr[sample], iArr[sample], learningRate);
            }
            learningRate *= LEARNING_RATE_DECAY;
        }
    }

    /**
     * Applies one gradient step for a single sample.
     *
     * <p>The update for each parameter is {@code rate * (target - probability) * input / n},
     * where {@code n} is the number of training samples (the input factor is omitted
     * for the bias).
     *
     * @param dArr input feature vector
     * @param iArr target vector for that input, one entry per class
     * @param d    learning rate for this step
     * @return per-class error {@code target - probability}, computed before the update
     */
    private double[] train(double[] dArr, int[] iArr, double d) {
        double[] probabilities = forward(dArr);
        double[] error = new double[this.n_out];
        for (int c = 0; c < this.n_out; c++) {
            error[c] = iArr[c] - probabilities[c];
            double[] weights = this.vW[c];
            for (int f = 0; f < this.n_in; f++) {
                weights[f] += ((d * error[c]) * dArr[f]) / this.n_N;
            }
            this.vb[c] += (d * error[c]) / this.n_N;
        }
        return error;
    }

    /**
     * Computes the class probabilities {@code softmax(W * x + b)} for one input.
     *
     * @param input feature vector
     * @return newly allocated probability vector, one entry per class
     */
    private double[] forward(double[] input) {
        double[] activations = new double[this.n_out];
        for (int c = 0; c < this.n_out; c++) {
            double sum = 0.0d;
            for (int f = 0; f < this.n_in; f++) {
                sum += this.vW[c][f] * input[f];
            }
            activations[c] = sum + this.vb[c];
        }
        softmax(activations);
        return activations;
    }

    /**
     * Replaces the given activations, in place, with their softmax.
     *
     * <p>The largest activation is subtracted before exponentiating. This leaves
     * the mathematical result unchanged but keeps {@link Math#exp(double)} from
     * overflowing to infinity (which would turn the output into NaN) when an
     * activation is large.
     *
     * @param dArr activations on entry, probabilities on return
     */
    private void softmax(double[] dArr) {
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.n_out; c++) {
            if (dArr[c] > max) {
                max = dArr[c];
            }
        }
        double total = 0.0d;
        for (int c = 0; c < this.n_out; c++) {
            dArr[c] = Math.exp(dArr[c] - max);
            total += dArr[c];
        }
        for (int c = 0; c < this.n_out; c++) {
            dArr[c] /= total;
        }
    }

    /**
     * Returns the weight matrix.
     *
     * @return the internal {@code [class][feature]} array itself, not a copy
     */
    public double[][] getW() {
        return this.vW;
    }

    /**
     * Returns the bias vector.
     *
     * @return the internal per-class array itself, not a copy
     */
    public double[] getB() {
        return this.vb;
    }

    /**
     * Computes the class probabilities for one input vector.
     *
     * @param dArr feature vector with at least as many entries as the model has
     *             input features
     * @return newly allocated probability vector, one entry per class
     */
    public double[] predict(double[] dArr) {
        return forward(dArr);
    }

    /**
     * Small demonstration: trains on nine two-dimensional samples in three classes
     * and prints the predicted probabilities for three test points.
     *
     * @param strArr unused
     */
    public static void main(String[] strArr) {
        double[][] trainX = {
            {-2.0d, 2.0d}, {-2.1d, 1.9d}, {-1.8d, 2.1d},
            {0.0d, 0.0d}, {0.2d, -0.2d}, {-0.1d, 0.1d},
            {2.0d, -2.0d}, {2.2d, -2.1d}, {1.9d, -2.0d}
        };
        int[][] trainT = {
            {1, 0, 0}, {1, 0, 0}, {1, 0, 0},
            {0, 1, 0}, {0, 1, 0}, {0, 1, 0},
            {0, 0, 1}, {0, 0, 1}, {0, 0, 1}
        };
        double[][] testX = {{-2.5d, 2.0d}, {0.1d, -0.1d}, {1.5d, -2.5d}};
        double[][] results = new double[testX.length][trainT[0].length];

        HybsLogisticRegression model = new HybsLogisticRegression(trainX, trainT, 0.1d, 500, 1.0d);
        for (int k = 0; k < testX.length; k++) {
            results[k] = model.predict(testX[k]);
            System.out.print(Arrays.toString(results[k]));
        }
    }
}
