package elki.clustering.em;

import elki.clustering.ClusteringAlgorithm;
import elki.clustering.em.models.EMClusterModel;
import elki.clustering.em.models.EMClusterModelFactory;
import elki.clustering.em.models.MultivariateGaussianModelFactory;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.model.MeanModel;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.LongStatistic;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.Priority;
import elki.utilities.documentation.Description;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.References;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.jafama.FastMath;

@References({@Reference(authors = "A. P. Dempster, N. M. Laird, D. B. Rubin", title = "Maximum Likelihood from Incomplete Data via the EM algorithm", booktitle = "Journal of the Royal Statistical Society, Series B, 39(1)", url = "http://www.jstor.org/stable/2984875", bibkey = "journals/jroyastatsocise2/DempsterLR77"), @Reference(title = "Bayesian Regularization for Normal Mixture Estimation and Model-Based Clustering", authors = "C. Fraley, A. E. Raftery", booktitle = "J. Classification 24(2)", url = "https://doi.org/10.1007/s00357-007-0004-5", bibkey = "DBLP:journals/classification/FraleyR07")})
@Description("Cluster data via Gaussian mixture modeling and the EM algorithm")
@Title("EM-Clustering: Clustering by Expectation Maximization")
@Priority(200)
/* loaded from: input_file:elki/clustering/em/EM.class */
public class EM<O, M extends MeanModel> implements ClusteringAlgorithm<Clustering<M>> {
    /** Number of clusters; passed to the model factory and used to size the result. */
    protected int k;
    /**
     * Convergence threshold on the mean log-likelihood: used both as the
     * minimum gain that counts as an improvement and as the maximum change
     * between two iterations at which iteration stops.
     */
    protected double delta;
    /** Factory that builds the initial cluster models. */
    protected EMClusterModelFactory<? super O, M> mfactory;
    /** Minimum number of iterations before the stopping criteria are evaluated. */
    protected int miniter;
    /** Maximum number of iterations; a negative value disables the limit. */
    protected int maxiter;
    /** Prior for MAP estimation; values &lt;= 0 select the plain (MLE) cluster weights. */
    protected double prior;
    /** Whether the per-object cluster probabilities are attached to the result. */
    protected boolean soft;
    /** Lower bound applied to every per-model log density before normalization. */
    protected static final double MIN_LOGLIKELIHOOD = -100000.0d;
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(EM.class);
    /** Prefix for statistics keys emitted by this class. */
    private static final String KEY = EM.class.getName();
    /** Type information of the soft assignment relation (one double[] per object). */
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation<>(double[].class);

    /* loaded from: input_file:elki/clustering/em/EM$Par.class */
    /**
     * Parameterization helper: reads the EM options from a
     * {@link Parameterization} and instantiates {@link EM} from them.
     *
     * @param <O> object type handled by the model factory
     * @param <M> cluster model type
     */
    public static class Par<O, M extends MeanModel> implements Parameterizer {
        /** Option: number of clusters. */
        public static final OptionID K_ID = new OptionID("em.k", "The number of clusters to find.");
        /** Option: convergence threshold. */
        public static final OptionID DELTA_ID = new OptionID("em.delta", "The termination criterion for maximization of E(M): E(M) - E(M') < em.delta");
        /** Option: cluster model factory. */
        public static final OptionID MODEL_ID = new OptionID("em.model", "Model factory.");
        /** Option: minimum number of iterations. */
        public static final OptionID MINITER_ID = new OptionID("em.miniter", "Minimum number of iterations.");
        /** Option: maximum number of iterations. */
        public static final OptionID MAXITER_ID = new OptionID("em.maxiter", "Maximum number of iterations.");
        /** Option: prior for MAP estimation. */
        public static final OptionID PRIOR_ID = new OptionID("em.map.prior", "Regularization factor for MAP estimation.");
        /** Option: keep the soft assignment in the result. */
        public static final OptionID SOFT_ID = new OptionID("em.soft", "Retain soft assignment of clusters.");
        /** Number of clusters (mandatory option, at least 1). */
        protected int k;
        /** Convergence threshold (defaults to 1e-7). */
        protected double delta;
        /** Model factory (defaults to {@link MultivariateGaussianModelFactory}). */
        protected EMClusterModelFactory<O, M> mfactory;
        /** Minimum number of iterations; stays 1 unless the option is given. */
        protected int miniter = 1;
        /** Maximum number of iterations; stays -1 (no limit) unless the option is given. */
        protected int maxiter = -1;
        /** MAP prior; stays 0 (disabled) unless the option is given. */
        double prior = 0.0d;
        /** Soft assignment flag. */
        boolean soft = false;

        /**
         * Read all EM options from the given parameterization.
         *
         * @param parameterization source of the option values
         */
        public void configure(Parameterization parameterization) {
            new IntParameter(K_ID) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                    .grab(parameterization, x -> this.k = x);
            // The explicit type argument is required: with a raw ObjectParameter the
            // lambda argument would be an Object and could not be stored in mfactory.
            new ObjectParameter<EMClusterModelFactory<O, M>>(MODEL_ID, EMClusterModelFactory.class, MultivariateGaussianModelFactory.class) //
                    .grab(parameterization, x -> this.mfactory = x);
            new DoubleParameter(DELTA_ID, 1.0E-7d) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE) //
                    .grab(parameterization, x -> this.delta = x);
            new IntParameter(MINITER_ID) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT) //
                    .setOptional(true) //
                    .grab(parameterization, x -> this.miniter = x);
            new IntParameter(MAXITER_ID) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT) //
                    .setOptional(true) //
                    .grab(parameterization, x -> this.maxiter = x);
            new DoubleParameter(PRIOR_ID) //
                    .setOptional(true) //
                    .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
                    .grab(parameterization, x -> this.prior = x);
            new Flag(SOFT_ID).grab(parameterization, x -> this.soft = x);
        }

        /**
         * Instantiate the algorithm from the configured values.
         *
         * @return a new EM instance
         */
        public EM<O, M> make() {
            return new EM<>(this.k, this.delta, this.mfactory, this.miniter, this.maxiter, this.prior, this.soft);
        }

        /**
         * Alias of {@link #make()} under the name the decompiler assigned when it
         * merged the bridge method; kept so existing callers of this name still
         * compile.
         *
         * @return a new EM instance
         */
        public EM<O, M> m100make() {
            return make();
        }
    }

    /**
     * Constructor using defaults for the remaining settings: no iteration
     * limit (maxiter = -1), no MAP prior (0) and no soft assignment output.
     *
     * @param i number of clusters
     * @param d convergence threshold
     * @param eMClusterModelFactory factory for the initial cluster models
     */
    public EM(int i, double d, EMClusterModelFactory<? super O, M> eMClusterModelFactory) {
        this(i, d, eMClusterModelFactory, -1, 0.0d, false);
    }

    /**
     * Constructor without MAP prior (the prior is passed on as 0).
     *
     * @param i number of clusters
     * @param d convergence threshold
     * @param eMClusterModelFactory factory for the initial cluster models
     * @param i2 maximum number of iterations (negative: unlimited)
     * @param z whether to retain the soft cluster assignment
     */
    public EM(int i, double d, EMClusterModelFactory<? super O, M> eMClusterModelFactory, int i2, boolean z) {
        this(i, d, eMClusterModelFactory, i2, 0.0d, z);
    }

    /**
     * Constructor using a minimum of 1 iteration.
     *
     * @param i number of clusters
     * @param d convergence threshold
     * @param eMClusterModelFactory factory for the initial cluster models
     * @param i2 maximum number of iterations (negative: unlimited)
     * @param d2 MAP prior (values &lt;= 0 disable MAP estimation)
     * @param z whether to retain the soft cluster assignment
     */
    public EM(int i, double d, EMClusterModelFactory<? super O, M> eMClusterModelFactory, int i2, double d2, boolean z) {
        this(i, d, eMClusterModelFactory, 1, i2, d2, z);
    }

    /**
     * Full constructor; all other constructors delegate here.
     *
     * @param i number of clusters
     * @param d convergence threshold
     * @param eMClusterModelFactory factory for the initial cluster models
     * @param i2 minimum number of iterations
     * @param i3 maximum number of iterations (negative: unlimited)
     * @param d2 MAP prior (values &lt;= 0 disable MAP estimation)
     * @param z whether to retain the soft cluster assignment
     */
    public EM(int i, double d, EMClusterModelFactory<? super O, M> eMClusterModelFactory, int i2, int i3, double d2, boolean z) {
        // prior is assigned exactly once; the former "prior = 0" pre-assignment
        // was a dead store that was always overwritten below.
        this.k = i;
        this.delta = d;
        this.mfactory = eMClusterModelFactory;
        this.miniter = i2;
        this.maxiter = i3;
        this.prior = d2;
        this.soft = z;
    }

    /**
     * Describe the input this algorithm accepts.
     *
     * @return the result of {@code TypeUtil.array} for the single restriction
     *         {@code TypeUtil.NUMBER_VECTOR_FIELD}
     */
    public TypeInformation[] getInputTypeRestriction() {
        final TypeInformation[] accepted = { TypeUtil.NUMBER_VECTOR_FIELD };
        return TypeUtil.array(accepted);
    }

    /**
     * Run EM clustering on the given relation.
     * <p>
     * Builds the initial models, then alternates between re-estimating the
     * models and re-assigning the cluster probabilities until one of the
     * stopping rules applies. Finally every object is assigned to its most
     * probable cluster.
     *
     * @param relation data to cluster
     * @return clustering with one top-level cluster per model; if soft
     *         assignment is enabled, the probability relation is attached as
     *         a child of the result
     * @throws IllegalArgumentException if the relation is empty
     */
    public Clustering<M> run(Relation<O> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException("database empty: must contain elements");
        }
        List<? extends EMClusterModel<? super O, M>> models = this.mfactory.buildInitialModels(relation, this.k);
        // NOTE(review): 10 is a storage hint bit mask whose named constants are
        // not visible in this file - confirm against the data store factory.
        WritableDataStore<double[]> probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), 10, double[].class);
        double loglikelihood = assignProbabilitiesToInstances(relation, models, probClusterIGivenX, null);
        DoubleStatistic likestat = new DoubleStatistic(getClass().getName() + ".loglikelihood");
        LOG.statistics(likestat.setDouble(loglikelihood));

        int it = 0;
        int lastImprovement = 0;
        // Best mean log-likelihood seen so far, to detect stagnation.
        double bestLoglikelihood = Double.NEGATIVE_INFINITY;
        // The counter starts at 1; a negative maxiter means "no limit".
        for (++it; it < this.maxiter || this.maxiter < 0; it++) {
            final double oldLoglikelihood = loglikelihood;
            recomputeCovarianceMatrices(relation, probClusterIGivenX, models, this.prior);
            loglikelihood = assignProbabilitiesToInstances(relation, models, probClusterIGivenX, null);
            LOG.statistics(likestat.setDouble(loglikelihood));
            if (loglikelihood - bestLoglikelihood > this.delta) {
                lastImprovement = it;
                bestLoglikelihood = loglikelihood;
            }
            // Stop once the minimum iteration count is reached and either the
            // change is within delta, or the last real improvement happened
            // before the first half of the iterations performed so far.
            if (it >= this.miniter && (Math.abs(loglikelihood - oldLoglikelihood) <= this.delta || lastImprovement < (it >> 1))) {
                break;
            }
        }
        LOG.statistics(new LongStatistic(KEY + ".iterations", it));

        // Hard assignment: every object goes to its most probable cluster.
        List<ModifiableDBIDs> hardClusters = new ArrayList<>(this.k);
        for (int c = 0; c < this.k; c++) {
            hardClusters.add(DBIDUtil.newArray());
        }
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            hardClusters.get(VMath.argmax(probClusterIGivenX.get(iter))).add(iter);
        }

        Clustering<M> result = new Clustering<>();
        Metadata.of(result).setLongName("EM Clustering");
        for (int c = 0; c < this.k; c++) {
            result.addToplevelCluster(new Cluster<>(hardClusters.get(c), models.get(c).finalizeCluster()));
        }
        if (this.soft) {
            Metadata.hierarchyOf(result).addChild(new MaterializedRelation<>("EM Cluster Probabilities", SOFT_TYPE, relation.getDBIDs(), probClusterIGivenX));
        } else {
            // The probabilities are not part of the result; release the storage.
            probClusterIGivenX.destroy();
        }
        return result;
    }

    /**
     * Re-estimate all cluster models from the current cluster probabilities.
     * <p>
     * If at least one model requests it, a first pass over the data is
     * performed before the main update pass. Probabilities of at most 1e-10
     * are not forwarded to the models, but still count towards the cluster
     * weight sums.
     *
     * @param <O> object type
     * @param relation data relation
     * @param writableDataStore per-object cluster probabilities (one entry per model)
     * @param list cluster models to update
     * @param d MAP prior; values &lt;= 0 use the plain weight
     *        {@code wsum / n}, otherwise
     *        {@code (wsum + d - 1) / (n + d * k - k)}
     */
    public static <O> void recomputeCovarianceMatrices(Relation<? extends O> relation, WritableDataStore<double[]> writableDataStore, List<? extends EMClusterModel<? super O, ?>> list, double d) {
        final int k = list.size();
        boolean needsTwoPass = false;
        for (EMClusterModel<? super O, ?> model : list) {
            model.beginEStep();
            needsTwoPass |= model.needsTwoPass();
        }
        if (needsTwoPass) {
            for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
                final double[] probs = writableDataStore.get(iter);
                // Typed as O (not Object) so it is accepted by the "? super O" models.
                final O instance = relation.get(iter);
                for (int i = 0; i < probs.length; i++) {
                    final double prob = probs[i];
                    if (prob > 1.0E-10d) {
                        list.get(i).firstPassE(instance, prob);
                    }
                }
            }
            for (EMClusterModel<? super O, ?> model : list) {
                model.finalizeFirstPassE();
            }
        }
        // Sum of the probabilities per cluster.
        final double[] wsum = new double[k];
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            final double[] probs = writableDataStore.get(iter);
            final O instance = relation.get(iter);
            for (int i = 0; i < probs.length; i++) {
                final double prob = probs[i];
                if (prob > 1.0E-10d) {
                    list.get(i).updateE(instance, prob);
                }
                wsum[i] += prob;
            }
        }
        final int n = relation.size();
        for (int i = 0; i < k; i++) {
            final double weight = d <= 0.0d ? wsum[i] / n : (wsum[i] + d - 1.0d) / (n + d * k - k);
            list.get(i).finalizeEStep(weight, d);
        }
    }

    /**
     * Compute the cluster probabilities of every object under the given
     * models and store them.
     * <p>
     * For each object the log densities of all models are evaluated, bounded
     * below by {@link #MIN_LOGLIKELIHOOD}, and normalized via
     * {@link #logSumExp(double[])} so that the stored values sum to one.
     *
     * @param <O> object type
     * @param relation data relation
     * @param list cluster models
     * @param writableDataStore output: normalized probabilities per object
     * @param writableDoubleDataStore optional output: the log-sum-exp value per
     *        object; may be {@code null}
     * @return the mean of the per-object log-sum-exp values
     */
    public static <O> double assignProbabilitiesToInstances(Relation<? extends O> relation, List<? extends EMClusterModel<? super O, ?>> list, WritableDataStore<double[]> writableDataStore, WritableDoubleDataStore writableDoubleDataStore) {
        final int k = list.size();
        double sum = 0.0d;
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            // Typed as O (not Object) so it is accepted by the "? super O" models.
            final O instance = relation.get(iter);
            final double[] probs = new double[k];
            for (int i = 0; i < k; i++) {
                final double logDensity = list.get(i).estimateLogDensity(instance);
                // Written this way so that NaN is replaced by the lower bound as well.
                probs[i] = logDensity > MIN_LOGLIKELIHOOD ? logDensity : MIN_LOGLIKELIHOOD;
            }
            final double logP = logSumExp(probs);
            for (int i = 0; i < k; i++) {
                probs[i] = FastMath.exp(probs[i] - logP);
            }
            writableDataStore.put(iter, probs);
            if (writableDoubleDataStore != null) {
                writableDoubleDataStore.put(iter, logP);
            }
            sum += logP;
        }
        return sum / relation.size();
    }

    /**
     * Compute {@code log(sum_i exp(dArr[i]))} relative to the largest entry.
     * <p>
     * Entries that are not larger than {@code max - 35.350506209} (about
     * 51 * ln 2) are skipped; entries equal to the maximum contribute exactly 1.
     *
     * @param dArr input values in log space; must contain at least one element
     * @return the log of the sum of the exponentiated values
     */
    public static double logSumExp(double[] dArr) {
        double max = dArr[0];
        for (int idx = 1; idx < dArr.length; idx++) {
            if (dArr[idx] > max) {
                max = dArr[idx];
            }
        }
        final double cutoff = max - 35.350506209d;
        double acc = 0.0d;
        for (final double value : dArr) {
            if (value > cutoff) {
                if (value < max) {
                    acc += FastMath.exp(value - max);
                } else {
                    acc += 1.0d;
                }
            }
        }
        if (acc > 1.0d) {
            return max + FastMath.log(acc);
        }
        return max;
    }

    /**
     * Compute {@code log(exp(d) + exp(d2))}, using the larger argument as the
     * reference point for the exponential.
     *
     * @param d first value in log space
     * @param d2 second value in log space
     * @return the log of the sum of both exponentiated values
     */
    protected static double logSumExp(double d, double d2) {
        if (d > d2) {
            return d + FastMath.log(FastMath.exp(d2 - d) + 1.0d);
        }
        return d2 + FastMath.log(FastMath.exp(d - d2) + 1.0d);
    }
}
