package moa.learners.featureanalysis;

import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.Iterator;
import moa.capabilities.CapabilitiesHandler;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/learners/featureanalysis/FeatureImportanceHoeffdingTree.class */
/**
 * Classifier that delegates learning and prediction to a wrapped {@link HoeffdingTree} and derives
 * per-feature importance scores from the split nodes of that tree.
 *
 * <p>Two scoring methods are available through {@link #featureImportanceOption}:
 * <ul>
 *   <li>{@code MDI}: for every split node, the entropy of the class distribution observed at the
 *       leaves below the node minus the weighted entropy of its children's distributions, summed
 *       per split attribute.</li>
 *   <li>{@code COVER}: for every split node, the total weight of the class distribution observed at
 *       the leaves below the node, summed per split attribute.</li>
 * </ul>
 */
public class FeatureImportanceHoeffdingTree extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler, FeatureImportanceClassifier {
    /** Cached importance scores; sized on the first training instance, {@code null} before that. */
    protected double[] featureImportances;
    /** Chosen index of {@link #featureImportanceOption} selecting mean decrease in impurity. */
    protected static final int FEATURE_IMPORTANCE_MDI = 0;
    /** Chosen index of {@link #featureImportanceOption} selecting cover. */
    protected static final int FEATURE_IMPORTANCE_COVER = 1;
    public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', "Decision Tree learner.", HoeffdingTree.class, "HoeffdingTree");
    public MultiChoiceOption featureImportanceOption = new MultiChoiceOption("featureImportance", 'o', "Which method to use for feature importance estimations.", new String[]{"MDI", "COVER"}, new String[]{"MDI", "COVER"}, 0);
    protected HoeffdingTree treeLearner = null;
    /** Tree node count at the time {@link #featureImportances} was last recomputed. */
    protected int nodeCountAtLastFeatureImportanceInquiry = 0;
    /** Number of times {@link #featureImportances} has been recomputed. */
    protected int featureImportancesInquiries = 0;
    /** Whether the cached {@link #featureImportances} were normalized when last recomputed. */
    protected boolean featureImportancesNormalized = false;

    /**
     * Returns the importance score of every non-class feature.
     *
     * <p>Scores are recomputed only when the tree's node count changed or when the requested
     * normalization differs from that of the cached scores; otherwise the cached array is returned.
     *
     * @param normalize if true, scale scores so they sum to 1 (left unscaled when the total is 0,
     *                  e.g. before the tree has made any split, to avoid NaN)
     * @return the internal score array, or {@code null} if no instance has been trained on yet
     */
    @Override // moa.learners.featureanalysis.FeatureImportanceClassifier
    public double[] getFeatureImportances(boolean normalize) {
        if (this.featureImportances == null || this.treeLearner.getTreeRoot() == null) {
            return this.featureImportances;
        }
        int nodeCount = this.treeLearner.getNodeCount();
        // The normalization flag is part of the cache key: a cached raw array must not be
        // returned to a caller that asked for normalized scores (and vice versa).
        boolean stale = nodeCount != this.nodeCountAtLastFeatureImportanceInquiry
                || normalize != this.featureImportancesNormalized;
        if (stale) {
            this.featureImportancesInquiries++;
            this.featureImportances = new double[this.featureImportances.length];
            this.nodeCountAtLastFeatureImportanceInquiry = nodeCount;
            this.featureImportancesNormalized = normalize;
            accumulateImportances(this.treeLearner.getTreeRoot(), this.featureImportanceOption.getChosenIndex());
            if (normalize) {
                double total = Utils.sum(this.featureImportances);
                if (total > 0.0) {
                    for (int f = 0; f < this.featureImportances.length; f++) {
                        this.featureImportances[f] /= total;
                    }
                }
            }
        }
        return this.featureImportances;
    }

    /**
     * Returns the indices of the {@code k} highest-scoring features, best first (ties resolved
     * towards the lower index).
     *
     * @param k         number of features requested; clamped to {@code [0, numFeatures]}
     * @param normalize forwarded to {@link #getFeatureImportances(boolean)}
     * @return the feature indices, or {@code null} if no importances are available yet
     */
    @Override // moa.learners.featureanalysis.FeatureImportanceClassifier
    public int[] getTopKFeatures(int k, boolean normalize) {
        double[] importances = getFeatureImportances(normalize);
        if (importances == null) {
            return null;
        }
        int count = Math.max(0, Math.min(k, importances.length));
        // Work on a copy so the cached scores are not clobbered by the selection sentinels.
        double[] remaining = importances.clone();
        int[] top = new int[count];
        for (int rank = 0; rank < count; rank++) {
            int best = Utils.maxIndex(remaining);
            top[rank] = best;
            remaining[best] = Double.NEGATIVE_INFINITY;
        }
        return top;
    }

    /**
     * Walks the subtree rooted at {@code node} and adds each split node's contribution to the
     * score of the attribute it splits on.
     *
     * @param node   subtree root; anything other than a {@link HoeffdingTree.SplitNode} contributes nothing
     * @param method {@link #FEATURE_IMPORTANCE_MDI} or {@link #FEATURE_IMPORTANCE_COVER}
     * @throws IllegalStateException if a split attribute index falls outside the score array
     */
    private void accumulateImportances(HoeffdingTree.Node node, int method) {
        if (!(node instanceof HoeffdingTree.SplitNode)) {
            return;
        }
        HoeffdingTree.SplitNode splitNode = (HoeffdingTree.SplitNode) node;
        // Only the first attribute the split test depends on is credited.
        int attIndex = splitNode.getSplitTest().getAttsTestDependsOn()[0];
        if (attIndex < 0 || attIndex >= this.featureImportances.length) {
            throw new IllegalStateException("Split attribute index " + attIndex
                    + " is outside the feature importance array of length " + this.featureImportances.length);
        }
        this.featureImportances[attIndex] += nodeImportance(splitNode, method);
        for (HoeffdingTree.Node child : splitNode.getChildren()) {
            if (child != null) {
                accumulateImportances(child, method);
            }
        }
    }

    /** Returns one split node's contribution under the given method (0 for an unknown method). */
    private double nodeImportance(HoeffdingTree.SplitNode splitNode, int method) {
        switch (method) {
            case FEATURE_IMPORTANCE_MDI:
                return calcNodeDecreaseImpurity(splitNode);
            case FEATURE_IMPORTANCE_COVER:
                return calcNodeCover(splitNode);
            default:
                return 0.0;
        }
    }

    /**
     * Returns the cover of a split node: the total weight of the class distribution observed at
     * the leaves reachable through it.
     */
    public double calcNodeCover(HoeffdingTree.SplitNode splitNode) {
        return Utils.sum(splitNode.getObservedClassDistributionAtLeavesReachableThroughThisNode());
    }

    /**
     * Returns the entropy decrease achieved by a split node: the entropy of its observed class
     * distribution minus the weight-averaged entropy of its non-null children's distributions.
     *
     * @return the decrease, or 0 if the node has observed no weight (avoids 0/0)
     */
    public double calcNodeDecreaseImpurity(HoeffdingTree.SplitNode splitNode) {
        double[] parentDist = splitNode.getObservedClassDistributionAtLeavesReachableThroughThisNode();
        double totalWeight = Utils.sum(parentDist);
        if (totalWeight <= 0.0) {
            return 0.0;
        }
        double weightedChildEntropy = 0.0;
        for (HoeffdingTree.Node child : splitNode.getChildren()) {
            if (child != null) {
                double[] childDist = child.getObservedClassDistributionAtLeavesReachableThroughThisNode();
                // Keep the weight as a double: instance weights may be fractional.
                weightedChildEntropy += (Utils.sum(childDist) / totalWeight) * InfoGainSplitCriterion.computeEntropy(childDist);
            }
        }
        return InfoGainSplitCriterion.computeEntropy(parentDist) - weightedChildEntropy;
    }

    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        return this.treeLearner.getVotesForInstance(instance);
    }

    /** Clears the cached importances and re-creates the wrapped tree learner from its option. */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.featureImportances = null;
        this.nodeCountAtLastFeatureImportanceInquiry = 0;
        this.featureImportancesInquiries = 0;
        this.featureImportancesNormalized = false;
        this.treeLearner = (HoeffdingTree) getPreparedClassOption(this.treeLearnerOption);
        this.treeLearner.resetLearning();
    }

    /**
     * Trains the wrapped tree. On the first instance, allocates one importance slot per attribute
     * except one (presumably the class attribute — confirm against the split-test index convention).
     */
    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        if (this.featureImportances == null) {
            this.featureImportances = new double[instance.numAttributes() - 1];
        }
        this.treeLearner.trainOnInstance(instance);
    }

    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        return this.treeLearner.getModelMeasurements();
    }

    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
        this.treeLearner.getModelDescription(sb, i);
    }

    /** Reports the wrapped tree's randomizability, or false before the tree exists. */
    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        if (this.treeLearner == null) {
            return false;
        }
        return this.treeLearner.isRandomizable();
    }
}
