package elki.clustering.em.models;

import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.index.tree.betula.features.ClusterFeature;
import elki.math.MathUtil;
import elki.math.linearalgebra.VMath;
import net.jafama.FastMath;

/* loaded from: input_file:elki/clustering/em/models/SphericalGaussianModel.class */
/**
 * Gaussian mixture component with a single variance shared by all dimensions,
 * i.e. a covariance matrix of the form {@code variance * I} (see
 * {@link #finalizeCluster()}).
 * <p>
 * During the E-step, the mean and the weighted scatter are accumulated
 * incrementally, either from individual vectors or from BETULA cluster
 * features. {@link #finalizeEStep(double, double)} then turns the scatter into
 * a per-dimension variance estimate.
 */
public class SphericalGaussianModel implements BetulaClusterModel {
    /** Lower bound for the variance, to avoid singular (zero-variance) components. */
    private static final double SINGULARITY_CHEAT = 1.0E-10d;

    /** Current mean. Updated in place during the E-step. */
    double[] mean;

    /**
     * Variance per dimension after {@link #finalizeEStep}; during the E-step
     * this field accumulates the weighted scatter instead.
     */
    double variance;

    /** Scratch buffer holding the updated mean during incremental updates. */
    double[] nmea;

    /** Gaussian normalization term, {@code dim * MathUtil.LOGTWOPI}. */
    double logNorm;

    /** Cached {@code log(weight) - 0.5 * (logNorm + dim * log(variance))}. */
    double logNormDet;

    /** Mixing weight of this component. */
    double weight;

    /** Sum of the weights accumulated in the current E-step. */
    double wsum;

    /** Prior variance used to regularize the estimate when a prior is given. */
    double priorvar;

    /**
     * Constructor with unit initial variance.
     *
     * @param weight initial mixing weight
     * @param mean initial mean; the array is used directly and updated in place
     */
    public SphericalGaussianModel(double weight, double[] mean) {
        this(weight, mean, 1.0d);
    }

    /**
     * Constructor.
     *
     * @param weight initial mixing weight
     * @param mean initial mean; the array is used directly and updated in place
     * @param var initial variance; non-positive values are replaced by a tiny
     *        positive constant
     */
    public SphericalGaussianModel(double weight, double[] mean, double var) {
        this.weight = weight;
        this.mean = mean;
        this.logNorm = MathUtil.LOGTWOPI * mean.length;
        this.nmea = new double[mean.length];
        this.variance = var > 0.0d ? var : SINGULARITY_CHEAT;
        this.priorvar = this.variance;
        this.wsum = 0.0d;
        // Include the initial variance in the normalization, so that densities
        // computed before the first finalizeEStep are consistent with it.
        this.logNormDet = FastMath.log(weight) - 0.5d * (this.logNorm + mean.length * FastMath.log(this.variance));
    }

    /** Reset the accumulators before a new E-step. */
    @Override // elki.clustering.em.models.EMClusterModel
    public void beginEStep() {
        wsum = 0.0d;
        variance = 0.0d;
    }

    /**
     * Incrementally add a weighted vector to the mean and scatter accumulators.
     *
     * @param vec data vector, of the same dimensionality as the mean
     * @param wei non-negative, finite weight; negligible weights are ignored
     */
    @Override // elki.clustering.em.models.EMClusterModel
    public void updateE(NumberVector vec, double wei) {
        assert vec.getDimensionality() == mean.length;
        assert wei >= 0.0d && wei < Double.POSITIVE_INFINITY : wei;
        if (wei < Double.MIN_NORMAL) {
            return; // Negligible weight; also avoids 0/0 while wsum == 0.
        }
        final double nwsum = wsum + wei;
        final double f = wei / nwsum; // Divide only once.
        for (int i = 0; i < mean.length; i++) {
            final double vi = vec.doubleValue(i);
            nmea[i] = mean[i] + (vi - mean[i]) * f;
            // Scatter increment uses both the new and the old mean.
            variance += (vi - nmea[i]) * (vi - mean[i]) * wei;
        }
        wsum = nwsum;
        System.arraycopy(nmea, 0, mean, 0, mean.length);
    }

    /**
     * Finish the E-step: convert the accumulated scatter into a variance and
     * refresh the cached normalization.
     *
     * @param weight new mixing weight
     * @param prior weight of the prior variance; values {@code <= 0} disable it
     */
    @Override // elki.clustering.em.models.EMClusterModel
    public void finalizeEStep(double weight, double prior) {
        final int dim = mean.length;
        this.weight = weight;
        if (prior > 0.0d && priorvar > 0.0d) {
            // Regularized estimate, blending in the prior variance weighted by prior.
            variance = (variance / dim + prior * priorvar) / (wsum + prior * (2.0d * dim + 4.0d));
        } else if (wsum > 0.0d) {
            // Plain estimate: scatter averaged over dimensions and total weight.
            variance /= dim * wsum;
        }
        // Clamp the stored value (not just the log term), so the Mahalanobis
        // distance never divides by zero and stays consistent with logNormDet.
        variance = MathUtil.max(variance, SINGULARITY_CHEAT);
        logNormDet = FastMath.log(weight) - 0.5d * (logNorm + dim * FastMath.log(variance));
        // priorvar is package-visible and may have been reset to 0 externally.
        if (prior > 0.0d && priorvar == 0.0d) {
            priorvar = variance;
        }
    }

    /**
     * Squared Mahalanobis distance to the mean under the spherical covariance.
     *
     * @param vec vector, of the same dimensionality as the mean
     * @return squared Euclidean distance to the mean divided by the variance
     */
    public double mahalanobisDistance(double[] vec) {
        assert vec.length == mean.length;
        double sum = 0.0d;
        for (int i = 0; i < vec.length; i++) {
            final double diff = vec[i] - mean[i];
            sum += diff * diff;
        }
        return sum / variance;
    }

    /**
     * Squared Mahalanobis distance to the mean under the spherical covariance.
     *
     * @param vec vector, of the same dimensionality as the mean
     * @return squared Euclidean distance to the mean divided by the variance
     */
    public double mahalanobisDistance(NumberVector vec) {
        double sum = 0.0d;
        for (int i = 0; i < mean.length; i++) {
            final double diff = vec.doubleValue(i) - mean[i];
            sum += diff * diff;
        }
        return sum / variance;
    }

    /**
     * Weighted log density of a vector: log(weight) plus the Gaussian log
     * density, using the normalization cached in {@link #logNormDet}.
     */
    @Override // elki.clustering.em.models.EMClusterModel
    public double estimateLogDensity(NumberVector vec) {
        return -0.5d * mahalanobisDistance(vec) + logNormDet;
    }

    @Override // elki.clustering.em.models.EMClusterModel
    public double getWeight() {
        return weight;
    }

    /**
     * Set the mixing weight.
     * <p>
     * Note: the cached normalization {@link #logNormDet} is not refreshed here,
     * so {@link #estimateLogDensity(NumberVector)} reflects the new weight only
     * after the next {@link #finalizeEStep}.
     */
    @Override // elki.clustering.em.models.EMClusterModel
    public void setWeight(double weight) {
        this.weight = weight;
    }

    /**
     * Produce the final cluster model.
     *
     * @return model with a copy of the mean and covariance {@code variance * I}
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // elki.clustering.em.models.EMClusterModel
    public EMModel finalizeCluster() {
        final int dim = mean.length;
        // Copy the mean: the internal array is overwritten in place by updateE.
        return new EMModel(mean.clone(), VMath.timesEquals(VMath.identity(dim, dim), variance));
    }

    /**
     * Log density of a cluster feature. The feature's own variance (averaged
     * over dimensions) is added to the model variance.
     * <p>
     * NOTE(review): unlike {@link #estimateLogDensity(NumberVector)}, this does
     * not include log(weight); presumably the caller accounts for the mixing
     * weight. Confirm against the BETULA EM driver.
     */
    @Override // elki.clustering.em.models.BetulaClusterModel
    public double estimateLogDensity(ClusterFeature cf) {
        final int dim = mean.length;
        // Assumes cf.variance() is the total over all dimensions; TODO confirm.
        final double var = cf.variance() / dim + variance;
        double sum = 0.0d;
        for (int i = 0; i < dim; i++) {
            final double diff = cf.centroid(i) - mean[i];
            sum += diff * diff;
        }
        return -0.5d * (sum / var + logNorm + dim * FastMath.log(var));
    }

    /**
     * Incrementally add a weighted cluster feature to the mean and scatter
     * accumulators.
     *
     * @param cf cluster feature, of the same dimensionality as the mean
     * @param wei non-negative, finite weight; negligible weights are ignored
     */
    @Override // elki.clustering.em.models.BetulaClusterModel
    public void updateE(ClusterFeature cf, double wei) {
        assert cf.getDimensionality() == mean.length;
        assert wei >= 0.0d && wei < Double.POSITIVE_INFINITY : wei;
        if (wei < Double.MIN_NORMAL) {
            // Negligible weight. Without this guard, a zero weight while
            // wsum == 0 computes 0/0 and poisons the mean with NaN.
            return;
        }
        final double nwsum = wsum + wei;
        final double f = wei / nwsum; // Divide only once.
        double scatter = 0.0d;
        for (int i = 0; i < mean.length; i++) {
            final double ci = cf.centroid(i);
            nmea[i] = mean[i] + (ci - mean[i]) * f;
            // Within-feature scatter plus the contribution of the mean shift.
            scatter += wei * cf.variance(i) + wei * (ci - nmea[i]) * (ci - mean[i]);
        }
        variance += scatter;
        wsum = nwsum;
        System.arraycopy(nmea, 0, mean, 0, mean.length);
    }
}
