package elki.datasource.filter.transform;

import elki.data.ClassLabel;
import elki.data.NumberVector;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeUtil;
import elki.data.type.VectorFieldTypeInformation;
import elki.datasource.bundle.MultipleObjectsBundle;
import elki.datasource.filter.ObjectFilter;
import elki.datasource.filter.typeconversions.ClassLabelFilter;
import elki.logging.Logging;
import elki.math.linearalgebra.VMath;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:elki/datasource/filter/transform/AbstractSupervisedProjectionVectorFilter.class */
public abstract class AbstractSupervisedProjectionVectorFilter<V extends NumberVector> implements ObjectFilter {
    protected int tdim;

    /* loaded from: input_file:elki/datasource/filter/transform/AbstractSupervisedProjectionVectorFilter$Par.class */
    public static abstract class Par<V extends NumberVector> implements Parameterizer {
        public static final OptionID P_ID = new OptionID("projection.dim", "Projection dimensionality");
        protected int tdim;

        public void configure(Parameterization parameterization) {
            new IntParameter(P_ID, 2).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT).grab(parameterization, i -> {
                this.tdim = i;
            });
        }
    }

    public AbstractSupervisedProjectionVectorFilter(int i) {
        this.tdim = i;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public MultipleObjectsBundle filter(MultipleObjectsBundle multipleObjectsBundle) {
        int dataLength = multipleObjectsBundle.dataLength();
        if (dataLength == 0) {
            return multipleObjectsBundle;
        }
        List list = null;
        int i = 0;
        while (true) {
            if (i >= multipleObjectsBundle.metaLength()) {
                break;
            }
            SimpleTypeInformation meta = multipleObjectsBundle.meta(i);
            List column = multipleObjectsBundle.getColumn(i);
            if (TypeUtil.CLASSLABEL.isAssignableFromType(meta)) {
                list = column;
                break;
            }
            i++;
        }
        if (list == null) {
            getLogger().warning("No class label column found (try " + ClassLabelFilter.class.getSimpleName() + ") -- cannot run " + getClass().getSimpleName());
            return multipleObjectsBundle;
        }
        boolean z = false;
        MultipleObjectsBundle multipleObjectsBundle2 = new MultipleObjectsBundle();
        for (int i2 = 0; i2 < multipleObjectsBundle.metaLength(); i2++) {
            VectorFieldTypeInformation meta2 = multipleObjectsBundle.meta(i2);
            List column2 = multipleObjectsBundle.getColumn(i2);
            if (TypeUtil.NUMBER_VECTOR_FIELD.isAssignableFromType(meta2)) {
                VectorFieldTypeInformation vectorFieldTypeInformation = meta2;
                NumberVector.Factory factory = vectorFieldTypeInformation.getFactory();
                int dimensionality = vectorFieldTypeInformation.getDimensionality();
                if (this.tdim > dimensionality) {
                    if (getLogger().isVerbose()) {
                        getLogger().verbose("Setting projection dimension to original dimension: projection dimension: " + this.tdim + " larger than original dimension: " + dimensionality);
                    }
                    this.tdim = dimensionality;
                }
                try {
                    double[][] computeProjectionMatrix = computeProjectionMatrix(column2, list, dimensionality);
                    for (int i3 = 0; i3 < dataLength; i3++) {
                        column2.set(i3, factory.newNumberVector(VMath.times(computeProjectionMatrix, ((NumberVector) column2.get(i3)).toArray())));
                    }
                    multipleObjectsBundle2.appendColumn(convertedType(meta2, factory), column2);
                    z = true;
                } catch (Exception e) {
                    getLogger().error("Projection failed -- continuing with unprojected data!", e);
                    multipleObjectsBundle2.appendColumn(meta2, column2);
                }
            } else {
                multipleObjectsBundle2.appendColumn(meta2, column2);
            }
        }
        if (z) {
            return multipleObjectsBundle2;
        }
        getLogger().warning("No vector field of fixed dimensionality found.");
        return multipleObjectsBundle;
    }

    protected SimpleTypeInformation<?> convertedType(SimpleTypeInformation<?> simpleTypeInformation, NumberVector.Factory<V> factory) {
        return new VectorFieldTypeInformation(factory, this.tdim);
    }

    protected abstract Logging getLogger();

    protected abstract double[][] computeProjectionMatrix(List<V> list, List<? extends ClassLabel> list2, int i);

    /* JADX INFO: Access modifiers changed from: protected */
    public <O> Map<O, IntList> partition(List<? extends O> list) {
        HashMap hashMap = new HashMap();
        int i = 0;
        for (O o : list) {
            IntArrayList intArrayList = (IntList) hashMap.get(o);
            if (intArrayList == null) {
                intArrayList = new IntArrayList();
                hashMap.put(o, intArrayList);
            }
            intArrayList.add(i);
            i++;
        }
        return hashMap;
    }
}
