package org.hipparchus.stat.projection;

import org.hipparchus.exception.MathIllegalStateException;
import org.hipparchus.linear.EigenDecompositionSymmetric;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.stat.LocalizedStatFormats;
import org.hipparchus.stat.StatUtils;
import org.hipparchus.stat.correlation.Covariance;
import org.hipparchus.stat.descriptive.moment.StandardDeviation;

/**
 * Principal component analysis (PCA).
 * <p>
 * Data sets are passed as {@code double[][]} arrays whose rows are observations and whose
 * columns are features. Fitting centers every feature on its mean and, when scaling is
 * enabled, divides it by its standard deviation; the eigenvectors of the covariance matrix
 * of this normalized data then serve as principal components, of which the first
 * {@link #getNumComponents()} are retained.
 * </p>
 * <p>
 * Instances are mutable ({@link #fit(double[][])} replaces the fitted model) and perform
 * no synchronization.
 * </p>
 */
public class PCA {
    /** Number of principal components retained. */
    private final int numC;
    /** Whether centered features are also divided by their standard deviation. */
    private final boolean scale;
    /** Whether the standard deviation used for scaling is bias corrected. */
    private final boolean biasCorrection;
    /** Per-feature means of the last successfully fitted data; {@code null} until then. */
    private double[] center;
    /** Per-feature standard deviations of the last fitted data; {@code null} when not scaling. */
    private double[] std;
    /** Eigenvalues of the covariance matrix of the normalized training data. */
    private double[] eigenValues;
    /** Feature-by-component matrix whose columns are the retained eigenvectors. */
    private RealMatrix principalComponents;
    /** Standard deviation estimator used for scaling; {@code null} when scaling is disabled. */
    private final StandardDeviation sd;

    /**
     * Creates a PCA instance.
     *
     * @param numC number of principal components to retain
     * @param scale whether to divide each centered feature by its standard deviation
     * @param biasCorrection whether the standard deviation used for scaling is bias corrected
     */
    public PCA(int numC, boolean scale, boolean biasCorrection) {
        this.numC = numC;
        this.scale = scale;
        this.biasCorrection = biasCorrection;
        this.sd = scale ? new StandardDeviation(biasCorrection) : null;
    }

    /**
     * Creates a PCA instance that only centers the data (no scaling).
     *
     * @param numC number of principal components to retain
     */
    public PCA(int numC) {
        this(numC, false, true);
    }

    /**
     * Returns the number of principal components retained.
     *
     * @return number of retained components
     */
    public int getNumComponents() {
        return numC;
    }

    /**
     * Tells whether features are scaled by their standard deviation.
     *
     * @return true if scaling is enabled
     */
    public boolean isScale() {
        return scale;
    }

    /**
     * Tells whether the standard deviation used for scaling is bias corrected.
     *
     * @return true if bias correction is enabled
     */
    public boolean isBiasCorrection() {
        return biasCorrection;
    }

    /**
     * Returns the eigenvalues of the covariance matrix of the normalized training data.
     * All eigenvalues are returned, not only those of the retained components (presumably in
     * the decreasing order produced by {@code EigenDecompositionSymmetric} — confirm).
     *
     * @return a copy of the eigenvalues
     * @throws MathIllegalStateException if no fit has completed successfully
     */
    public double[] getVariance() {
        validateState("getVariance");
        return eigenValues.clone();
    }

    /**
     * Returns the per-feature means computed during the last fit.
     *
     * @return a copy of the feature means
     * @throws MathIllegalStateException if no fit has completed successfully
     */
    public double[] getCenter() {
        validateState("getCenter");
        return center.clone();
    }

    /**
     * Returns the principal components, one row per feature and one column per retained
     * component.
     *
     * @return the components matrix entries
     * @throws MathIllegalStateException if no fit has completed successfully
     */
    public double[][] getComponents() {
        validateState("getComponents");
        return principalComponents.getData();
    }

    /**
     * Fits the model to {@code data} and returns the projection of that same data onto the
     * retained principal components.
     *
     * @param data observations (rows) by features (columns)
     * @return projected data, observations (rows) by components (columns)
     */
    public double[][] fitAndTransform(double[][] data) {
        final RealMatrix normalized = fitModel(data);
        return normalized.multiply(principalComponents).getData();
    }

    /**
     * Projects {@code data} onto the retained principal components, normalizing it with the
     * center (and standard deviations, if scaling) from the last fit.
     *
     * @param data observations (rows) by features (columns); the column count is expected to
     *     match the fitted data
     * @return projected data, observations (rows) by components (columns)
     * @throws MathIllegalStateException if no fit has completed successfully
     */
    public double[][] transform(double[][] data) {
        validateState("transform");
        return normalize(data, center, std).multiply(principalComponents).getData();
    }

    /**
     * Fits the model to {@code data}. If fitting fails, the previously fitted model (if any)
     * is left untouched.
     *
     * @param data observations (rows) by features (columns)
     * @return this instance, for call chaining
     */
    public PCA fit(double[][] data) {
        fitModel(data);
        return this;
    }

    /**
     * Ensures a fit has completed before a fitted-model accessor is used.
     *
     * @param from name of the calling method, reported in the exception
     * @throws MathIllegalStateException if no fit has completed successfully
     */
    private void validateState(String from) {
        // center is only assigned once the whole fit has succeeded (see fitModel),
        // so a non-null center guarantees the other fitted fields are set too.
        if (center == null) {
            throw new MathIllegalStateException(LocalizedStatFormats.ILLEGAL_STATE_PCA, from);
        }
    }

    /**
     * Computes normalization parameters and principal components for {@code data}, then
     * commits them to this instance only after every step has succeeded.
     *
     * @param data observations (rows) by features (columns)
     * @return the normalized copy of {@code data} used for fitting
     */
    private RealMatrix fitModel(double[][] data) {
        final int nFeatures = data[0].length;
        final double[] newCenter = new double[nFeatures];
        final double[] newStd = scale ? new double[nFeatures] : null;
        for (int f = 0; f < nFeatures; f++) {
            final double[] values = column(data, f);
            newCenter[f] = StatUtils.mean(values);
            if (scale) {
                newStd[f] = sd.evaluate(values, newCenter[f]);
            }
        }

        final RealMatrix normalized = normalize(data, newCenter, newStd);
        final EigenDecompositionSymmetric decomposition =
            new EigenDecompositionSymmetric(new Covariance(normalized).getCovarianceMatrix());
        final double[] newEigenValues = decomposition.getEigenvalues();
        final RealMatrix newComponents = MatrixUtils.createRealMatrix(newEigenValues.length, numC);
        for (int c = 0; c < numC; c++) {
            // Fetch each eigenvector once and store it as a whole column.
            newComponents.setColumnVector(c, decomposition.getEigenvector(c));
        }

        // Commit only now: any exception above leaves the previous model intact
        // instead of a half-updated one.
        eigenValues = newEigenValues;
        principalComponents = newComponents;
        std = newStd;
        center = newCenter;
        return normalized;
    }

    /**
     * Extracts one feature column from {@code data}.
     *
     * @param data observations (rows) by features (columns)
     * @param feature index of the column to extract
     * @return the column values, one per observation
     */
    private static double[] column(double[][] data, int feature) {
        final double[] values = new double[data.length];
        for (int r = 0; r < data.length; r++) {
            values[r] = data[r][feature];
        }
        return values;
    }

    /**
     * Returns a centered (and, when scaling is enabled, scaled) copy of {@code data}.
     *
     * @param data observations (rows) by features (columns)
     * @param means per-feature means to subtract
     * @param stds per-feature standard deviations to divide by; only read when scaling
     * @return normalized copy of {@code data}
     */
    private RealMatrix normalize(double[][] data, double[] means, double[] stds) {
        final int nRows = data.length;
        final int nCols = data[0].length;
        final double[][] normalized = new double[nRows][nCols];
        for (int r = 0; r < nRows; r++) {
            for (int f = 0; f < nCols; f++) {
                final double centered = data[r][f] - means[f];
                normalized[r][f] = scale ? centered / stds[f] : centered;
            }
        }
        return MatrixUtils.createRealMatrix(normalized);
    }
}
