package elki.application.benchmark;

import elki.application.AbstractDistanceBasedApplication;
import elki.data.NumberVector;
import elki.data.type.TypeUtil;
import elki.data.type.VectorFieldTypeInformation;
import elki.database.Database;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.DoubleDBIDList;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.query.QueryBuilder;
import elki.database.query.range.RangeSearcher;
import elki.database.relation.Relation;
import elki.database.relation.RelationUtil;
import elki.datasource.DatabaseConnection;
import elki.datasource.bundle.MultipleObjectsBundle;
import elki.distance.Distance;
import elki.index.Index;
import elki.logging.Logging;
import elki.logging.LoggingConfiguration;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.MillisTimeDuration;
import elki.logging.statistics.StringStatistic;
import elki.math.MathUtil;
import elki.math.MeanVariance;
import elki.result.Metadata;
import elki.utilities.Util;
import elki.utilities.datastructures.arrays.ArrayUtil;
import elki.utilities.datastructures.iterator.It;
import elki.utilities.exceptions.IncompatibleDataException;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;
import elki.workflow.InputStep;
import java.util.Arrays;

/**
 * Benchmark application that runs one range query per sampled query point and
 * reports timing, result-size statistics and a result checksum.
 * <p>
 * The query radius comes from one of three sources:
 * <ol>
 * <li>a constant radius ({@code -rangebench.radius}), querying by database
 * object;</li>
 * <li>otherwise a separate query data source ({@code -rangebench.query}), whose
 * vectors must have exactly one dimension more than the database: the leading
 * values are the query point, the last value is the query radius;</li>
 * <li>otherwise a one-dimensional number vector relation in the database,
 * providing a radius for each database object.</li>
 * </ol>
 * All output is emitted as statistics, so the logging level must be at least
 * STATISTICS.
 *
 * @param <O> Vector type
 */
public class RangeQueryBenchmark<O extends NumberVector> extends AbstractDistanceBasedApplication<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(RangeQueryBenchmark.class);

    /** Constant query radius; {@code NaN} when the radius comes from elsewhere. */
    protected double radius = Double.NaN;

    /** Data source for the queries; {@code null} when queries come from the database. */
    protected DatabaseConnection queries = null;

    /**
     * Sampling parameter: values up to 1 are a relative share, larger values an
     * absolute number of queries. The default of -1 means "use all data".
     */
    protected double sampling = -1.0d;

    /** Random generator factory used for sampling. */
    protected RandomFactory random;

    /**
     * Constructor for a constant query radius.
     *
     * @param inputStep Data input step
     * @param distance Distance function
     * @param radius Constant query radius
     * @param sampling Sampling parameter (see {@link #sampling})
     * @param random Random generator factory
     */
    public RangeQueryBenchmark(InputStep inputStep, Distance<? super O> distance, double radius, double sampling, RandomFactory random) {
        super(inputStep, distance);
        this.radius = radius;
        this.sampling = sampling;
        this.random = random;
    }

    /**
     * Constructor for queries read from a separate data source (or, if
     * {@code queries} is {@code null}, for per-object radii stored in the
     * database).
     *
     * @param inputStep Data input step
     * @param distance Distance function
     * @param queries Query data source, may be {@code null}
     * @param sampling Sampling parameter (see {@link #sampling})
     * @param random Random generator factory
     */
    public RangeQueryBenchmark(InputStep inputStep, Distance<? super O> distance, DatabaseConnection queries, double sampling, RandomFactory random) {
        super(inputStep, distance);
        this.queries = queries;
        this.sampling = sampling;
        this.random = random;
    }

    /**
     * Run the benchmark and log the statistics.
     */
    public void run() {
        if (!LOG.isStatistics()) {
            LOG.error("Logging level should be at least level STATISTICS (parameter -time) to see any output.");
        }
        final Database database = inputstep.getDatabase();
        final Relation<O> relation = database.getRelation(distance.getInputTypeRestriction());
        final String key = getClass().getName();
        final Duration duration = LOG.newDuration(key + ".duration");
        final MeanVariance mv = new MeanVariance();
        final int hash;
        // Index statistics are logged after building the query (which may build
        // the index) and again after the benchmark, to see the query cost.
        if (!Double.isNaN(radius)) {
            RangeSearcher<DBIDRef> rangeQuery = new QueryBuilder<>(relation, distance).rangeByDBID(radius);
            logIndexStatistics(database);
            hash = run(rangeQuery, relation, radius, duration, mv);
        }
        else if (queries != null) {
            RangeSearcher<O> rangeQuery = new QueryBuilder<>(relation, distance).rangeByObject();
            logIndexStatistics(database);
            hash = run(rangeQuery, relation, queries, duration, mv);
        }
        else {
            RangeSearcher<DBIDRef> rangeQuery = new QueryBuilder<>(relation, distance).rangeByDBID();
            logIndexStatistics(database);
            Relation<NumberVector> radii = database.getRelation(TypeUtil.NUMBER_VECTOR_FIELD_1D);
            hash = run(rangeQuery, relation, radii, duration, mv);
        }
        // The run helpers already stopped the timer; calling end() again here
        // would also count the time spent after the query loop.
        LOG.statistics(duration);
        if (duration instanceof MillisTimeDuration && mv.getCount() > 0) {
            // getDuration() is in milliseconds; 1 ms = 1e6 ns.
            LOG.statistics(new StringStatistic(key + ".duration.avg", duration.getDuration() / mv.getCount() * 1e6 + " ns"));
        }
        LOG.statistics(new DoubleStatistic(key + ".results.mean", mv.getMean()));
        LOG.statistics(new DoubleStatistic(key + ".results.std", mv.getPopulationStddev()));
        logIndexStatistics(database);
        LOG.statistics(new LongStatistic(key + ".checksum", hash));
    }

    /**
     * Log the statistics of all indexes found below the database in the
     * result hierarchy.
     *
     * @param database Database
     */
    private void logIndexStatistics(Database database) {
        for (It<Index> it = Metadata.hierarchyOf(database).iterDescendants().filter(Index.class); it.valid(); it.advance()) {
            it.get().logStatistics();
        }
    }

    /**
     * Run range queries with a constant radius for a sample of the database.
     * Starts and stops {@code duration} around the query loop.
     *
     * @param rangeQuery Range query by object id
     * @param relation Data relation to sample query ids from
     * @param queryRadius Constant query radius
     * @param duration Timer for the query loop
     * @param mv Statistics collector for the result sizes
     * @return Checksum of the results
     */
    protected int run(RangeSearcher<DBIDRef> rangeQuery, Relation<O> relation, double queryRadius, Duration duration, MeanVariance mv) {
        final DBIDs sample = DBIDUtil.randomSample(relation.getDBIDs(), sampling, random);
        FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("Range queries", sample.size(), LOG) : null;
        int hash = 0;
        duration.begin();
        for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
            hash = Util.mixHashCodes(hash, processResult(rangeQuery.getRange(iditer, queryRadius), mv));
            LOG.incrementProcessed(prog);
        }
        duration.end();
        LOG.ensureCompleted(prog);
        return hash;
    }

    /**
     * Run range queries for a sample of the database, taking each query's
     * radius from the first value of {@code radii}.
     * Starts and stops {@code duration} around the query loop.
     *
     * @param rangeQuery Range query by object id
     * @param relation Data relation to sample query ids from
     * @param radii Relation providing the per-object radius (dimension 0)
     * @param duration Timer for the query loop
     * @param mv Statistics collector for the result sizes
     * @return Checksum of the results
     */
    protected int run(RangeSearcher<DBIDRef> rangeQuery, Relation<O> relation, Relation<NumberVector> radii, Duration duration, MeanVariance mv) {
        final DBIDs sample = DBIDUtil.randomSample(relation.getDBIDs(), sampling, random);
        FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("Range queries", sample.size(), LOG) : null;
        int hash = 0;
        duration.begin();
        for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
            final double r = radii.get(iditer).doubleValue(0);
            hash = Util.mixHashCodes(hash, processResult(rangeQuery.getRange(iditer, r), mv));
            LOG.incrementProcessed(prog);
        }
        duration.end();
        LOG.ensureCompleted(prog);
        return hash;
    }

    /**
     * Run range queries read from a separate data source. Each query vector
     * has one more dimension than the data: the leading values form the query
     * point, the last value is the query radius.
     * Starts and stops {@code duration} around the query loop (data loading is
     * not timed).
     *
     * @param rangeQuery Range query by object
     * @param relation Data relation (provides dimensionality and vector factory)
     * @param queries Query data source
     * @param duration Timer for the query loop
     * @param mv Statistics collector for the result sizes
     * @return Checksum of the results
     * @throws IncompatibleDataException if the query source has no column of
     *         vectors with dimensionality + 1 dimensions
     */
    protected int run(RangeSearcher<O> rangeQuery, Relation<O> relation, DatabaseConnection queries, Duration duration, MeanVariance mv) {
        final NumberVector.Factory<O> ofactory = RelationUtil.getNumberVectorFactory(relation);
        final int dim = RelationUtil.dimensionality(relation);
        final VectorFieldTypeInformation<NumberVector> qtype = VectorFieldTypeInformation.typeRequest(NumberVector.class, dim + 1, dim + 1);
        final MultipleObjectsBundle bundle = queries.loadData();
        final int col = findQueryColumn(bundle, qtype);
        final int[] sample = sampleQueryOffsets(bundle.dataLength());
        FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("Range queries", sample.length, LOG) : null;
        int hash = 0;
        duration.begin();
        final double[] buf = new double[dim];
        for (int off : sample) {
            final NumberVector q = (NumberVector) bundle.data(off, col);
            for (int d = 0; d < dim; d++) {
                buf[d] = q.doubleValue(d);
            }
            final O v = ofactory.newNumberVector(buf);
            // The extra trailing dimension holds the query radius.
            hash = Util.mixHashCodes(hash, processResult(rangeQuery.getRange(v, q.doubleValue(dim)), mv));
            LOG.incrementProcessed(prog);
        }
        duration.end();
        LOG.ensureCompleted(prog);
        return hash;
    }

    /**
     * Find the first bundle column whose type matches the query type.
     *
     * @param bundle Loaded query data
     * @param type Required column type
     * @return Column index
     * @throws IncompatibleDataException if no column matches
     */
    private static int findQueryColumn(MultipleObjectsBundle bundle, VectorFieldTypeInformation<NumberVector> type) {
        for (int c = 0; c < bundle.metaLength(); c++) {
            if (type.isAssignableFromType(bundle.meta(c))) {
                return c;
            }
        }
        StringBuilder msg = new StringBuilder(1000) //
            .append("No compatible data type in query input was found. Expected: ").append(type.toString()).append(" have:");
        for (int c = 0; c < bundle.metaLength(); c++) {
            msg.append(' ').append(bundle.meta(c).toString());
        }
        throw new IncompatibleDataException(msg.toString());
    }

    /**
     * Choose the query offsets according to {@link #sampling}.
     * Non-positive sampling (the default) uses all queries in input order.
     * Values up to 1 select that share, and larger values select that many
     * queries. The sample size is clamped to {@code size}.
     *
     * @param size Number of available queries
     * @return Offsets of the queries to run
     */
    private int[] sampleQueryOffsets(int size) {
        final int[] offsets = MathUtil.sequence(0, size);
        if (sampling <= 0) {
            return offsets;
        }
        final int samplesize = (int) Math.min(sampling <= 1.0d ? sampling * size : sampling, size);
        ArrayUtil.randomShuffle(offsets, random.getSingleThreadedRandom(), samplesize);
        return Arrays.copyOf(offsets, samplesize);
    }

    /**
     * Record the result size and compute a checksum of the result ids.
     *
     * @param res Range query result
     * @param mv Statistics collector for the result sizes
     * @return Sum of the integer representations of the result ids
     */
    protected int processResult(DoubleDBIDList res, MeanVariance mv) {
        mv.put(res.size());
        int checksum = 0;
        for (DoubleDBIDListIter it = res.iter(); it.valid(); it.advance()) {
            checksum += DBIDUtil.asInteger(it);
        }
        return checksum;
    }

    /**
     * Parameterization class.
     *
     * @param <O> Vector type
     */
    public static class Par<O extends NumberVector> extends AbstractDistanceBasedApplication.Par<O> {
        /** Parameter for a constant query radius. */
        public static final OptionID RADIUS_ID = new OptionID("rangebench.radius", "Query radius to use a constant radius.");

        /** Parameter for the query data source. */
        public static final OptionID QUERY_ID = new OptionID("rangebench.query", "Data source for the queries. If not set, the queries are taken from the database.");

        /** Parameter for the sampling size. */
        public static final OptionID SAMPLING_ID = new OptionID("rangebench.sampling", "Sampling size parameter. If the value is less or equal 1, it is assumed to be the relative share. Larger values will be interpreted as integer sizes. By default, all data will be used.");

        /** Parameter for the random generator. */
        public static final OptionID RANDOM_ID = new OptionID("rangebench.random", "Random generator for sampling.");

        /** Query data source. */
        protected DatabaseConnection queries = null;

        /** Sampling parameter. */
        protected double sampling = -1.0d;

        /** Constant query radius. */
        protected double radius = Double.NaN;

        /** Random generator factory. */
        protected RandomFactory random;

        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            new DoubleParameter(RADIUS_ID) //
                .setOptional(true) //
                .grab(config, x -> radius = x);
            // A query source is only consulted when no constant radius was given.
            if (Double.isNaN(radius)) {
                new ObjectParameter<DatabaseConnection>(QUERY_ID, DatabaseConnection.class) //
                    .setOptional(true) //
                    .grab(config, x -> queries = x);
            }
            new DoubleParameter(SAMPLING_ID) //
                .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
                .setOptional(true) //
                .grab(config, x -> sampling = x);
            new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT) //
                .grab(config, x -> random = x);
        }

        @Override
        public RangeQueryBenchmark<O> make() {
            return Double.isNaN(radius) //
                ? new RangeQueryBenchmark<>(inputstep, distance, queries, sampling, random) //
                : new RangeQueryBenchmark<>(inputstep, distance, radius, sampling, random);
        }

        /**
         * Alias of {@link #make()} introduced by decompilation.
         *
         * @return New benchmark instance
         * @deprecated Use {@link #make()} instead.
         */
        @Deprecated
        public RangeQueryBenchmark<O> m52make() {
            return make();
        }
    }

    /**
     * Main method to run this application from the command line.
     *
     * @param args Command line parameters
     */
    public static void main(String[] args) {
        LoggingConfiguration.setDefaultLevel(Logging.Level.STATISTICS);
        runCLIApplication(RangeQueryBenchmark.class, args);
    }
}
