package moa.learners.featureanalysis;

import com.yahoo.labs.samoa.instances.Instance;
import moa.capabilities.CapabilitiesHandler;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/learners/featureanalysis/FeatureImportanceHoeffdingTreeEnsemble.class */
/**
 * Classifier wrapper that trains an ensemble learner and reports feature importances aggregated
 * over the ensemble's member trees.
 *
 * <p>Training, prediction and model measurements are delegated to the ensemble created from
 * {@link #ensembleLearnerOption}. Feature importances are obtained by wrapping every ensemble
 * member (cast to {@link HoeffdingTree}) in a copy of the {@link FeatureImportanceHoeffdingTree}
 * configured by {@link #hoeffdingTreeFeatureImportanceOption} and summing the per-tree scores.
 */
public class FeatureImportanceHoeffdingTreeEnsemble extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler, FeatureImportanceClassifier {
    /** Ensemble learner to train; its sub-learners are treated as {@link HoeffdingTree}s. */
    public ClassOption ensembleLearnerOption = new ClassOption("ensembleLearner", 'l', "Ensemble learner to train and analyze.", Classifier.class, "moa.classifiers.meta.AdaptiveRandomForest");
    /** Template used to compute per-tree importances; its tree learner is replaced per member. */
    public ClassOption hoeffdingTreeFeatureImportanceOption = new ClassOption("hoeffdingTreeFeatureImportance", 't', "Hoeffding Tree object to use. Its learner option is overridden by the ensemble base tree model.", FeatureImportanceHoeffdingTree.class, "FeatureImportanceHoeffdingTree");
    /** The wrapped ensemble, created in {@link #resetLearningImpl()}. */
    protected Classifier ensemble;
    /** Prepared template copied once per ensemble member when computing importances. */
    protected FeatureImportanceHoeffdingTree htFeatureImportanceBase;
    /**
     * Most recently computed importances, one entry per non-class attribute; {@code null} until
     * the first training instance is seen (the feature count is unknown before that).
     */
    protected double[] featureImportances;

    /**
     * Computes the ensemble's feature importances as the sum of the importances reported by each
     * member tree. NaN scores from individual trees count as 0.
     *
     * @param normalize passed to every per-tree computation; additionally, when true and the
     *     aggregated total is positive, the aggregated scores are divided by their sum. A zero
     *     total is left as all zeros rather than producing NaN.
     * @return the aggregated importance per non-class feature (the internal array, also stored in
     *     {@link #featureImportances}), the previously stored array if the ensemble exposes no
     *     sub-learners, or {@code null} if no instance has been trained on yet
     */
    @Override // moa.learners.featureanalysis.FeatureImportanceClassifier
    public double[] getFeatureImportances(boolean normalize) {
        if (this.featureImportances == null) {
            // The feature count is only known after trainOnInstanceImpl() has run; bail out
            // instead of dereferencing the null array further down.
            System.err.println("Unable to infer the number of features. trainOnInstance() must be invoked prior to getFeatureImportances()");
            return null;
        }
        // Read the members as Object[] and cast element-wise: a blanket (Classifier[]) cast
        // fails when getSublearners() returns an array whose runtime type is a supertype array.
        Object[] members = this.ensemble.getSublearners();
        if (members == null) {
            return this.featureImportances;
        }

        int numFeatures = this.featureImportances.length;
        double[] aggregated = new double[numFeatures];
        for (Object member : members) {
            FeatureImportanceHoeffdingTree wrapper = (FeatureImportanceHoeffdingTree) this.htFeatureImportanceBase.copy();
            wrapper.treeLearner = (HoeffdingTree) member;
            if (wrapper.featureImportances == null) {
                // The fresh copy has never seen data, so size its buffer from ours.
                wrapper.featureImportances = new double[numFeatures];
            }
            double[] treeImportances = wrapper.getFeatureImportances(normalize);
            if (treeImportances == null) {
                continue;
            }
            for (int f = 0; f < numFeatures; f++) {
                double value = treeImportances[f];
                aggregated[f] += Double.isNaN(value) ? 0.0d : value;
            }
        }

        if (normalize) {
            double total = Utils.sum(aggregated);
            if (total > 0.0d) {
                for (int f = 0; f < numFeatures; f++) {
                    aggregated[f] /= total;
                }
            }
        }
        this.featureImportances = aggregated;
        return this.featureImportances;
    }

    /**
     * Returns the indices of the {@code k} most important features, most important first.
     *
     * @param k number of features requested; clamped to {@code [0, numFeatures]}
     * @param normalize forwarded to {@link #getFeatureImportances(boolean)}
     * @return the feature indices in descending order of importance, or {@code null} if
     *     importances are not available yet
     */
    @Override // moa.learners.featureanalysis.FeatureImportanceClassifier
    public int[] getTopKFeatures(int k, boolean normalize) {
        // Compute once: every call re-copies and re-analyzes the whole ensemble.
        double[] importances = getFeatureImportances(normalize);
        if (importances == null) {
            return null;
        }
        int count = Math.max(0, Math.min(k, importances.length));
        double[] remaining = importances.clone();
        int[] top = new int[count];
        for (int rank = 0; rank < count; rank++) {
            int best = Utils.maxIndex(remaining);
            top[rank] = best;
            // Mark as taken so the next maxIndex() picks the next-best feature.
            remaining[best] = Double.NEGATIVE_INFINITY;
        }
        return top;
    }

    /**
     * Recreates the ensemble and the importance template from their options and forgets any
     * previously inferred feature count.
     */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.ensemble = (Classifier) getPreparedClassOption(this.ensembleLearnerOption);
        this.ensemble.resetLearning();
        this.htFeatureImportanceBase = (FeatureImportanceHoeffdingTree) getPreparedClassOption(this.hoeffdingTreeFeatureImportanceOption);
        // A reset learner may be fed a stream with a different number of attributes.
        this.featureImportances = null;
        if (this.ensemble.getSubClassifiers() == null) {
            System.err.println("The classifier is not an ensemble or does not implement the getSubClassifiers() method. ");
        }
    }

    /**
     * Trains the wrapped ensemble on {@code instance}. On the first call, sizes the importance
     * buffer to the number of attributes minus one (the class attribute).
     */
    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        if (this.featureImportances == null) {
            this.featureImportances = new double[instance.numAttributes() - 1];
        }
        this.ensemble.trainOnInstance(instance);
    }

    /** Delegates prediction to the wrapped ensemble. */
    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        return this.ensemble.getVotesForInstance(instance);
    }

    /** Delegates model measurements to the wrapped ensemble. */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        return this.ensemble.getModelMeasurements();
    }

    /** Intentionally emits no description. */
    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** This wrapper declares itself non-randomizable. */
    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return false;
    }
}
