package org.hipparchus.filtering.kalman.unscented;

import org.hipparchus.exception.LocalizedCoreFormats;
import org.hipparchus.exception.MathIllegalArgumentException;
import org.hipparchus.exception.MathRuntimeException;
import org.hipparchus.filtering.kalman.KalmanFilter;
import org.hipparchus.filtering.kalman.Measurement;
import org.hipparchus.filtering.kalman.ProcessEstimate;
import org.hipparchus.linear.MatrixDecomposer;
import org.hipparchus.linear.MatrixUtils;
import org.hipparchus.linear.RealMatrix;
import org.hipparchus.linear.RealVector;
import org.hipparchus.util.UnscentedTransformProvider;

/* loaded from: input_file:org/hipparchus/filtering/kalman/unscented/UnscentedKalmanFilter.class */
/**
 * Kalman filter that works on sets of sigma points: the points are produced by an
 * {@link UnscentedTransformProvider}, pushed through an {@link UnscentedProcess}, and
 * recombined into a predicted and then a corrected {@link ProcessEstimate}.
 *
 * <p>The filter is stateful: each estimation step replaces the predicted and corrected
 * estimates held by the instance. No synchronization is performed in this class.</p>
 *
 * @param <T> the type of the measurements
 */
public class UnscentedKalmanFilter<T extends Measurement> implements KalmanFilter<T> {

    /** Process providing the evolution, the predicted measurements and the innovation. */
    private final UnscentedProcess<T> process;

    /** Estimate produced by the last prediction ({@code null} until a step has been run). */
    private ProcessEstimate predicted;

    /** Estimate produced by the last correction (initially the estimate given at construction). */
    private ProcessEstimate corrected;

    /** Decomposer used to solve the linear system of the correction phase. */
    private final MatrixDecomposer decomposer;

    /** Dimension of the state vector of the initial estimate. */
    private final int n;

    /** Provider of sigma points, weights, weighted means and weighted covariances. */
    private final UnscentedTransformProvider utProvider;

    /**
     * Builds a filter starting from an initial estimate.
     *
     * @param decomposer decomposer used to solve the linear system of the correction phase
     * @param process process to estimate
     * @param initialState initial estimate; its state dimension fixes the filter dimension
     * @param utProvider provider of the unscented transform
     * @throws MathIllegalArgumentException if the state of {@code initialState} has dimension 0
     */
    public UnscentedKalmanFilter(MatrixDecomposer decomposer, UnscentedProcess<T> process,
                                 ProcessEstimate initialState, UnscentedTransformProvider utProvider) {
        this.decomposer = decomposer;
        this.process = process;
        this.corrected = initialState;
        this.n = this.corrected.getState().getDimension();
        this.utProvider = utProvider;
        if (this.n == 0) {
            throw new MathIllegalArgumentException(LocalizedCoreFormats.ZERO_STATE_SIZE);
        }
    }

    /**
     * Runs one full step: draws sigma points from the current corrected estimate and then
     * delegates to {@link #predictionAndCorrectionSteps(Measurement, RealVector[])}.
     *
     * @param measurement measurement to process
     * @return the corrected estimate after the step
     * @throws MathRuntimeException propagated from the underlying computations
     */
    @Override // org.hipparchus.filtering.kalman.KalmanFilter
    public ProcessEstimate estimationStep(T measurement) throws MathRuntimeException {
        final RealVector[] sigmaPoints =
                this.utProvider.unscentedTransform(this.corrected.getState(), this.corrected.getCovariance());
        return predictionAndCorrectionSteps(measurement, sigmaPoints);
    }

    /**
     * Runs the prediction phase followed by the correction phase, starting from
     * caller-supplied sigma points.
     *
     * <p>When no innovation covariance can be formed (the mean predicted measurement is
     * {@code null}) or the process returns a {@code null} innovation, the measurement is
     * ignored and the corrected estimate is the predicted one.</p>
     *
     * @param measurement measurement to process
     * @param sigmaPoints sigma points handed to the process evolution
     * @return the corrected estimate after the step
     * @throws MathRuntimeException propagated from the underlying computations
     */
    public ProcessEstimate predictionAndCorrectionSteps(T measurement, RealVector[] sigmaPoints)
        throws MathRuntimeException {

        // Prediction phase.
        final UnscentedEvolution evolution =
                this.process.getEvolution(getCorrected().getTime(), sigmaPoints, measurement);
        predict(evolution.getCurrentTime(), evolution.getCurrentStates(), evolution.getProcessNoiseMatrix());

        // Fresh sigma points drawn from the predicted estimate.
        final RealVector[] predictedSigmaPoints =
                this.utProvider.unscentedTransform(this.predicted.getState(), this.predicted.getCovariance());

        // Correction phase.
        final RealVector[] predictedMeasurements =
                this.process.getPredictedMeasurements(predictedSigmaPoints, measurement);
        final RealVector predictedMeasurement = this.utProvider.getUnscentedMeanState(predictedMeasurements);
        final RealMatrix innovationCovariance =
                computeInnovationCovarianceMatrix(predictedMeasurements, predictedMeasurement,
                                                  measurement.getCovariance());

        if (innovationCovariance == null) {
            // No mean predicted measurement is available. The cross covariance cannot be
            // built without it (it would dereference the null mean), so go straight to the
            // "ignore measurement" branch of correct().
            correct(measurement, null, null, null);
            return getCorrected();
        }

        final RealMatrix crossCovariance =
                computeCrossCovarianceMatrix(predictedSigmaPoints, this.predicted.getState(),
                                             predictedMeasurements, predictedMeasurement);
        final RealVector innovation =
                this.process.getInnovation(measurement, predictedMeasurement,
                                           this.predicted.getState(), innovationCovariance);
        correct(measurement, innovationCovariance, crossCovariance, innovation);
        return getCorrected();
    }

    /**
     * Prediction phase: stores the weighted mean of the propagated states and their weighted
     * covariance plus the process noise as the predicted estimate, and clears the corrected
     * estimate.
     *
     * @param time time of the predicted estimate
     * @param predictedStates propagated states, one per sigma point
     * @param noise process noise matrix added to the weighted covariance
     */
    private void predict(double time, RealVector[] predictedStates, RealMatrix noise) {
        final RealVector meanState = this.utProvider.getUnscentedMeanState(predictedStates);
        final RealMatrix covariance =
                this.utProvider.getUnscentedCovariance(predictedStates, meanState).add(noise);
        this.predicted = new ProcessEstimate(time, meanState, covariance);
        this.corrected = null;
    }

    /**
     * Correction phase.
     *
     * @param measurement measurement whose time stamps the corrected estimate
     * @param innovationCovarianceMatrix innovation covariance (unused when {@code innovation} is null)
     * @param crossCovarianceMatrix state/measurement cross covariance (unused when {@code innovation} is null)
     * @param innovation innovation vector, or {@code null} if the measurement must be ignored
     * @throws MathIllegalArgumentException propagated from the matrix operations
     */
    private void correct(T measurement, RealMatrix innovationCovarianceMatrix,
                         RealMatrix crossCovarianceMatrix, RealVector innovation)
        throws MathIllegalArgumentException {

        if (innovation == null) {
            // Measurement ignored: the corrected estimate is simply the predicted one.
            this.corrected = this.predicted;
            return;
        }

        // Gain k such that k.S = Pxz. Rather than inverting S, solve S.X = Pxz^T and
        // transpose X; this equals Pxz.S^-1 provided S is symmetric.
        final RealMatrix gain = this.decomposer
                .decompose(innovationCovarianceMatrix)
                .solve(crossCovarianceMatrix.transpose())
                .transpose();

        final RealVector correctedState = this.predicted.getState().add(gain.operate(innovation));

        // P - k.S.k^T
        final RealMatrix correctedCovariance = this.predicted.getCovariance()
                .subtract(gain.multiply(innovationCovarianceMatrix).multiplyTransposed(gain));

        this.corrected = new ProcessEstimate(measurement.getTime(), correctedState, correctedCovariance,
                                             null, null, innovationCovarianceMatrix, gain);
    }

    /**
     * Returns the estimate produced by the last prediction.
     *
     * @return the predicted estimate, {@code null} if no step has been run yet
     */
    @Override // org.hipparchus.filtering.kalman.KalmanFilter
    public ProcessEstimate getPredicted() {
        return this.predicted;
    }

    /**
     * Returns the current corrected estimate.
     *
     * @return the corrected estimate (the initial estimate until a step has been run)
     */
    @Override // org.hipparchus.filtering.kalman.KalmanFilter
    public ProcessEstimate getCorrected() {
        return this.corrected;
    }

    /**
     * Returns the unscented transform provider given at construction.
     *
     * @return the unscented transform provider
     */
    public UnscentedTransformProvider getUnscentedTransformProvider() {
        return this.utProvider;
    }

    /**
     * Computes the innovation covariance: weighted covariance of the predicted measurements
     * about their mean, plus the measurement covariance.
     *
     * @param predictedMeasurements predicted measurements, one per sigma point
     * @param predictedMeasurement mean predicted measurement, may be {@code null}
     * @param r measurement covariance to add
     * @return the innovation covariance, or {@code null} if {@code predictedMeasurement} is null
     */
    private RealMatrix computeInnovationCovarianceMatrix(RealVector[] predictedMeasurements,
                                                         RealVector predictedMeasurement,
                                                         RealMatrix r) {
        if (predictedMeasurement == null) {
            return null;
        }
        return this.utProvider.getUnscentedCovariance(predictedMeasurements, predictedMeasurement).add(r);
    }

    /**
     * Computes the weighted cross covariance between states and measurements:
     * the sum over i of {@code wc[i] * (x_i - x) (z_i - z)^T}.
     *
     * @param predictedStates sigma points of the predicted state
     * @param predictedState predicted state (subtracted from each sigma point)
     * @param predictedMeasurements predicted measurements, one per sigma point
     * @param predictedMeasurement mean predicted measurement, must not be {@code null}
     * @return matrix with as many rows as the state dimension and as many columns as the
     *         measurement dimension
     */
    private RealMatrix computeCrossCovarianceMatrix(RealVector[] predictedStates, RealVector predictedState,
                                                    RealVector[] predictedMeasurements,
                                                    RealVector predictedMeasurement) {
        RealMatrix crossCovariance =
                MatrixUtils.createRealMatrix(predictedState.getDimension(), predictedMeasurement.getDimension());
        final RealVector wc = this.utProvider.getWc();
        // Indices 0..2n inclusive, i.e. 2n+1 terms; this relies on both arrays and the weight
        // vector holding at least that many entries (NOTE(review): confirm for custom providers).
        for (int i = 0; i <= 2 * this.n; i++) {
            final RealVector stateDiff = predictedStates[i].subtract(predictedState);
            final RealVector measurementDiff = predictedMeasurements[i].subtract(predictedMeasurement);
            crossCovariance = crossCovariance.add(outer(stateDiff, measurementDiff).scalarMultiply(wc.getEntry(i)));
        }
        return crossCovariance;
    }

    /**
     * Computes the outer product {@code a b^T}.
     *
     * @param a left vector (gives the row count)
     * @param b right vector (gives the column count)
     * @return matrix whose (row, col) entry is {@code a[row] * b[col]}
     */
    private RealMatrix outer(RealVector a, RealVector b) {
        final RealMatrix product = MatrixUtils.createRealMatrix(a.getDimension(), b.getDimension());
        for (int row = 0; row < product.getRowDimension(); row++) {
            for (int col = 0; col < product.getColumnDimension(); col++) {
                product.setEntry(row, col, a.getEntry(row) * b.getEntry(col));
            }
        }
        return product;
    }
}
