package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import moa.AbstractMOAObject;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.classifiers.trees.ARFHoeffdingTree;
import moa.core.DoubleVector;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.evaluation.BasicClassificationPerformanceEvaluator;
import moa.options.ClassOption;
import org.apache.log4j.Priority;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/AdaptiveRandomForest.class */
public class AdaptiveRandomForest extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
    private static final long serialVersionUID = 1;

    /** Prototype tree learner; {@link #initEnsemble(Instance)} copies it once per ensemble member. */
    public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', "Random Forest Tree.", ARFHoeffdingTree.class, "ARFHoeffdingTree -e 2000000 -g 50 -c 0.01");

    /** Number of base learners created by {@link #initEnsemble(Instance)}. */
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of trees.", 100, 1, Integer.MAX_VALUE);

    /** Selects how {@link #mFeaturesPerTreeSizeOption} is turned into the per-tree subspace size (see the FEATURES_* constants). */
    public MultiChoiceOption mFeaturesModeOption = new MultiChoiceOption("mFeaturesMode", 'o', "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.", new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", "Percentage (M * (m / 100))"}, new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 3);

    /** Raw value of m; the lower bound is Integer.MIN_VALUE so that negative values (meaning M - m) are accepted. */
    public IntOption mFeaturesPerTreeSizeOption = new IntOption("mFeaturesPerTreeSize", 'm', "Number of features allowed considered for each split. Negative values corresponds to M - m", 60, Integer.MIN_VALUE, Integer.MAX_VALUE);

    /** Lambda handed to {@code MiscUtils.poisson} when drawing the per-tree training weight of an instance. */
    public FloatOption lambdaOption = new FloatOption("lambda", 'a', "The lambda parameter for bagging.", 6.0d, 1.0d, Float.MAX_VALUE);

    /** Requested number of training threads: -1 uses all available processors, 0 or 1 trains on the calling thread. */
    public IntOption numberOfJobsOption = new IntOption("numberOfJobs", 'j', "Total number of concurrent jobs used for processing (-1 = as much as possible, 0 = do not use multithreading)", 1, -1, Integer.MAX_VALUE);

    /** Detector whose change signal makes a base learner reset (or be replaced by its background learner). */
    public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-3");

    /** Detector whose change signal makes a base learner start a background learner. */
    public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-2");

    /** When set, member votes are combined without scaling them by the member's evaluator measurement. */
    public FlagOption disableWeightedVote = new FlagOption("disableWeightedVote", 'w', "Should use weighted voting?");

    /** When set, base learners are created without drift detection. */
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'u', "Should use drift detection? If disabled then bkg learner is also disabled");

    /** When set, base learners are created without background learners. */
    public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q', "Should use bkg learner? If disabled then reset tree immediately.");

    // Indices into mFeaturesModeOption, in the order its choices are declared.
    protected static final int FEATURES_M = 0;
    protected static final int FEATURES_SQRT = 1;
    protected static final int FEATURES_SQRT_INV = 2;
    protected static final int FEATURES_PERCENT = 3;

    /** Value of numberOfJobsOption that disables multithreading. */
    protected static final int SINGLE_THREAD = 0;

    /** Base learners; null until the first instance is seen (created lazily by initEnsemble). */
    protected ARFBaseLearner[] ensemble;

    /** Number of instances passed to trainOnInstanceImpl since the last reset. */
    protected long instancesSeen;

    /** Number of features each tree may consider, as resolved by initEnsemble. */
    protected int subspaceSize;

    /** Ensemble-level evaluator, re-created on every reset. */
    protected BasicClassificationPerformanceEvaluator evaluator;

    /**
     * Thread pool used for parallel training; null when training runs on the calling thread.
     * Marked transient because thread pools such as ThreadPoolExecutor are not Serializable,
     * so a live pool would otherwise make serialising this learner fail. After deserialisation
     * the field is null and training falls back to the calling thread until the next reset.
     */
    private transient ExecutorService executor;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/AdaptiveRandomForest$ARFBaseLearner.class */
    /**
     * One member of the forest: a tree plus its own evaluator, optional drift and
     * warning detectors, and an optional background learner that can replace the
     * tree when a drift is signalled.
     */
    public final class ARFBaseLearner extends AbstractMOAObject {
        /** Position of this learner in the ensemble at creation time. */
        public int indexOriginal;
        /** Value of the instance counter when this learner (or its promoted replacement) was created. */
        public long createdOn;
        /** Instance counter at the most recent drift signal, 0 if none. */
        public long lastDriftOn;
        /** Instance counter at the most recent warning signal, 0 if none. */
        public long lastWarningOn;
        public ARFHoeffdingTree classifier;
        /** True for learners created as background learners; they never run detectors themselves. */
        public boolean isBackgroundLearner;
        /** Only set when drift detection is enabled; null otherwise. */
        protected ClassOption driftOption;
        /** Only set when background learners are enabled; null otherwise. */
        protected ClassOption warningOption;
        protected ChangeDetector driftDetectionMethod;
        protected ChangeDetector warningDetectionMethod;
        public boolean useBkgLearner;
        public boolean useDriftDetector;
        /** Background learner started at the last warning, or null if none is active. */
        protected ARFBaseLearner bkgLearner;
        public BasicClassificationPerformanceEvaluator evaluator;
        protected int numberOfDriftsDetected;
        protected int numberOfWarningsDetected;

        /** Creates a fresh copy of the detector configured by the given option. */
        private ChangeDetector newDetector(ClassOption detectorOption) {
            return ((ChangeDetector) AdaptiveRandomForest.this.getPreparedClassOption(detectorOption)).copy();
        }

        /** Assigns all state; detectors are only instantiated for the features that are enabled. */
        private void init(int indexOriginal, ARFHoeffdingTree classifier, BasicClassificationPerformanceEvaluator evaluator, long createdOn, boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            this.indexOriginal = indexOriginal;
            this.createdOn = createdOn;
            this.lastDriftOn = 0L;
            this.lastWarningOn = 0L;
            this.classifier = classifier;
            this.evaluator = evaluator;
            this.useBkgLearner = useBkgLearner;
            this.useDriftDetector = useDriftDetector;
            this.numberOfDriftsDetected = 0;
            this.numberOfWarningsDetected = 0;
            this.isBackgroundLearner = isBackgroundLearner;
            if (this.useDriftDetector) {
                this.driftOption = driftOption;
                this.driftDetectionMethod = newDetector(this.driftOption);
            }
            if (this.useBkgLearner) {
                this.warningOption = warningOption;
                this.warningDetectionMethod = newDetector(this.warningOption);
            }
        }

        /**
         * @param indexOriginal       position of the learner in the ensemble
         * @param classifier          tree owned by this learner
         * @param evaluator           evaluator owned by this learner
         * @param createdOn           instance counter at creation time
         * @param useBkgLearner       whether warnings start a background learner
         * @param useDriftDetector    whether drift detection is performed
         * @param driftOption         option describing the drift detector
         * @param warningOption       option describing the warning detector
         * @param isBackgroundLearner whether this learner is itself a background learner
         */
        public ARFBaseLearner(int indexOriginal, ARFHoeffdingTree classifier, BasicClassificationPerformanceEvaluator evaluator, long createdOn, boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            init(indexOriginal, classifier, evaluator, createdOn, useBkgLearner, useDriftDetector, driftOption, warningOption, isBackgroundLearner);
        }

        /**
         * Replaces this learner's model. If a background learner is available it is
         * promoted (its tree, detectors, evaluator and creation time are adopted);
         * otherwise the tree is reset in place. The evaluator in effect afterwards is
         * always reset.
         */
        public void reset() {
            if (this.useBkgLearner && this.bkgLearner != null) {
                this.classifier = this.bkgLearner.classifier;
                this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod;
                this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod;
                this.evaluator = this.bkgLearner.evaluator;
                this.createdOn = this.bkgLearner.createdOn;
                this.bkgLearner = null;
            } else {
                this.classifier.resetLearning();
                this.createdOn = AdaptiveRandomForest.this.instancesSeen;
                // driftOption is only assigned when drift detection is enabled; this method
                // is public, so guard against being called on a learner without a detector.
                if (this.driftOption != null) {
                    this.driftDetectionMethod = newDetector(this.driftOption);
                }
            }
            this.evaluator.reset();
        }

        /** Starts a new background learner from a reset copy of the current tree and re-arms the warning detector. */
        private void startBackgroundLearner(long currentInstanceCount) {
            ARFHoeffdingTree backgroundTree = (ARFHoeffdingTree) this.classifier.copy();
            backgroundTree.resetLearning();
            BasicClassificationPerformanceEvaluator backgroundEvaluator = (BasicClassificationPerformanceEvaluator) this.evaluator.copy();
            backgroundEvaluator.reset();
            this.bkgLearner = new ARFBaseLearner(this.indexOriginal, backgroundTree, backgroundEvaluator, currentInstanceCount, this.useBkgLearner, this.useDriftDetector, this.driftOption, this.warningOption, true);
            this.warningDetectionMethod = newDetector(this.warningOption);
        }

        /**
         * Trains the tree on a re-weighted copy of the instance and then updates the
         * warning and drift detectors with whether the tree classifies it correctly.
         *
         * @param instance             training instance (not modified; a copy is re-weighted)
         * @param weight               multiplier applied to the instance weight for the main tree
         * @param currentInstanceCount instance counter recorded when a warning or drift fires
         */
        public void trainOnInstance(Instance instance, double weight, long currentInstanceCount) {
            Instance weightedInstance = instance.copy();
            weightedInstance.setWeight(instance.weight() * weight);
            this.classifier.trainOnInstance(weightedInstance);

            // NOTE(review): the background tree is trained on the original instance, not the
            // re-weighted copy. Behavior kept as-is; confirm this is intended.
            if (this.bkgLearner != null) {
                this.bkgLearner.classifier.trainOnInstance(instance);
            }

            // Background learners never run detectors, so they cannot spawn further background learners.
            if (!this.useDriftDetector || this.isBackgroundLearner) {
                return;
            }

            // Detectors receive 0 for a correct prediction and 1 for an error.
            double errorSignal = this.classifier.correctlyClassifies(instance) ? 0.0d : 1.0d;

            if (this.useBkgLearner) {
                this.warningDetectionMethod.input(errorSignal);
                if (this.warningDetectionMethod.getChange()) {
                    this.lastWarningOn = currentInstanceCount;
                    this.numberOfWarningsDetected++;
                    startBackgroundLearner(currentInstanceCount);
                }
            }

            this.driftDetectionMethod.input(errorSignal);
            if (this.driftDetectionMethod.getChange()) {
                this.lastDriftOn = currentInstanceCount;
                this.numberOfDriftsDetected++;
                reset();
            }
        }

        /**
         * @param instance instance to classify
         * @return the tree's votes, routed through a DoubleVector (presumably a defensive copy)
         */
        public double[] getVotesForInstance(Instance instance) {
            return new DoubleVector(this.classifier.getVotesForInstance(instance)).getArrayRef();
        }

        /** No textual description is produced for a base learner. */
        @Override // moa.MOAObject
        public void getDescription(StringBuilder sb, int i) {
        }
    }

    /* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/AdaptiveRandomForest$TrainingRunnable.class */
    /**
     * Unit of work that trains a single base learner on a single instance.
     * Usable both as a Runnable and as a Callable so it can be handed to
     * ExecutorService.invokeAll.
     */
    protected class TrainingRunnable implements Runnable, Callable<Integer> {
        private final ARFBaseLearner learner;
        private final Instance instance;
        private final double weight;
        private final long instancesSeen;

        /**
         * @param learner       base learner to train
         * @param instance      training instance
         * @param weight        weight multiplier passed on to the learner
         * @param instancesSeen instance counter passed on to the learner
         */
        public TrainingRunnable(ARFBaseLearner learner, Instance instance, double weight, long instancesSeen) {
            this.instancesSeen = instancesSeen;
            this.weight = weight;
            this.instance = instance;
            this.learner = learner;
        }

        /** Trains the learner with the captured arguments. */
        @Override // java.lang.Runnable
        public void run() {
            final ARFBaseLearner target = this.learner;
            target.trainOnInstance(this.instance, this.weight, this.instancesSeen);
        }

        /** Delegates to {@link #run()}; the returned value is always 0. */
        @Override // java.util.concurrent.Callable
        public Integer call() {
            run();
            return Integer.valueOf(0);
        }
    }

    /**
     * @return a fixed one-line, human-readable description of this learner
     */
    @Override // moa.classifiers.AbstractClassifier, moa.options.AbstractOptionHandler, moa.options.OptionHandler
    public String getPurposeString() {
        return "Adaptive Random Forest algorithm for evolving data streams from Gomes et al.";
    }

    /**
     * Discards the ensemble and counters and (re)creates the training thread pool
     * according to {@code numberOfJobsOption}. The ensemble itself is rebuilt
     * lazily when the next instance arrives.
     */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.ensemble = null;
        this.subspaceSize = 0;
        this.instancesSeen = 0L;
        this.evaluator = new BasicClassificationPerformanceEvaluator();

        // Release the pool left over from a previous run. Without this every reset leaks
        // a set of worker threads, and a stale pool would still be used after the option
        // has been switched back to single-threaded operation.
        if (this.executor != null) {
            this.executor.shutdown();
            this.executor = null;
        }

        int requestedJobs = this.numberOfJobsOption.getValue();
        int numberOfJobs = requestedJobs == -1 ? Runtime.getRuntime().availableProcessors() : requestedJobs;

        // Zero jobs and a single job are equivalent: train on the calling thread, no pool.
        if (numberOfJobs != SINGLE_THREAD && numberOfJobs != 1) {
            this.executor = Executors.newFixedThreadPool(numberOfJobs);
        }
    }

    /**
     * Processes one training instance: every base learner is first evaluated on
     * it and then, if its Poisson draw is positive, trained on it with that draw
     * as weight. Training runs on the thread pool when one is configured and on
     * the calling thread otherwise.
     *
     * @param instance the training instance
     * @throws RuntimeException if the calling thread is interrupted while waiting
     *                          for the training tasks; the interrupt flag is restored
     */
    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        this.instancesSeen++;
        if (this.ensemble == null) {
            initEnsemble(instance);
        }

        ArrayList<TrainingRunnable> pendingTasks = new ArrayList<TrainingRunnable>();
        for (int i = 0; i < this.ensemble.length; i++) {
            ARFBaseLearner member = this.ensemble[i];

            // Evaluate before training so the member's evaluator sees its prediction on unseen data.
            member.evaluator.addResult(new InstanceExample(instance), member.getVotesForInstance(instance));

            int weight = MiscUtils.poisson(this.lambdaOption.getValue(), this.classifierRandom);
            if (weight > 0) {
                if (this.executor != null) {
                    pendingTasks.add(new TrainingRunnable(member, instance, weight, this.instancesSeen));
                } else {
                    member.trainOnInstance(instance, weight, this.instancesSeen);
                }
            }
        }

        if (this.executor != null) {
            try {
                // Blocks until every queued training task has finished.
                // NOTE(review): exceptions thrown inside a task are captured in the returned
                // futures and are not inspected here.
                this.executor.invokeAll(pendingTasks);
            } catch (InterruptedException e) {
                // Restore the flag so callers further up can observe the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Could not call invokeAll() on training threads.", e);
            }
        }
    }

    /**
     * Combines the votes of all base learners. Each member's vote is normalized
     * and, unless weighted voting is disabled, scaled by the member's second
     * evaluator measurement when that measurement is positive. Members whose
     * votes do not sum to a positive value are ignored.
     *
     * @param instance the instance to classify (a copy is used internally)
     * @return the summed votes over all contributing members
     */
    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        if (this.ensemble == null) {
            initEnsemble(testInstance);
        }

        DoubleVector combinedVotes = new DoubleVector();
        for (ARFBaseLearner member : this.ensemble) {
            DoubleVector memberVotes = new DoubleVector(member.getVotesForInstance(testInstance));

            // Written as a negated '>' so that a NaN sum is skipped as well.
            if (!(memberVotes.sumOfValues() > 0.0d)) {
                continue;
            }
            memberVotes.normalize();

            double memberWeight = member.evaluator.getPerformanceMeasurements()[1].getValue();
            boolean applyWeight = !this.disableWeightedVote.isSet() && memberWeight > 0.0d;
            if (applyWeight) {
                int classCount = 0;
                while (classCount < memberVotes.numValues()) {
                    memberVotes.setValue(classCount, memberVotes.getValue(classCount) * memberWeight);
                    classCount++;
                }
            }
            combinedVotes.addValues(memberVotes);
        }
        return combinedVotes.getArrayRef();
    }

    /**
     * @return always true; training draws Poisson weights from the classifier's random source
     */
    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Produces no model description: the builder is left untouched.
     *
     * @param sb builder that would receive the description
     * @param i  unused here; presumably the indentation level (confirm against AbstractClassifier)
     */
    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /**
     * Reports no model-specific measurements.
     *
     * @return always null. NOTE(review): callers presumably tolerate null; confirm against
     *         AbstractClassifier before relying on it.
     */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Builds the ensemble: resolves the per-tree feature subspace size from the
     * configured mode and the instance's attribute count, then creates
     * {@code ensembleSizeOption} base learners, each with its own copy of the
     * configured tree and of a fresh evaluator.
     *
     * @param instance any instance of the stream; only its attribute count is read
     */
    protected void initEnsemble(Instance instance) {
        int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new ARFBaseLearner[ensembleSize];
        BasicClassificationPerformanceEvaluator evaluatorPrototype = new BasicClassificationPerformanceEvaluator();

        this.subspaceSize = this.mFeaturesPerTreeSizeOption.getValue();
        // One attribute is excluded from the count (presumably the class attribute).
        int totalFeatures = instance.numAttributes() - 1;

        switch (this.mFeaturesModeOption.getChosenIndex()) {
            case FEATURES_SQRT:
                this.subspaceSize = ((int) Math.round(Math.sqrt(totalFeatures))) + 1;
                break;
            case FEATURES_SQRT_INV:
                this.subspaceSize = totalFeatures - ((int) Math.round(Math.sqrt(totalFeatures) + 1.0d));
                break;
            case FEATURES_PERCENT:
                // A negative m means "100% minus |m|%".
                double fraction = this.subspaceSize < 0 ? (100 + this.subspaceSize) / 100.0d : this.subspaceSize / 100.0d;
                this.subspaceSize = (int) Math.round(totalFeatures * fraction);
                break;
            case FEATURES_M:
            default:
                // Use the configured value as-is.
                break;
        }

        // A negative size is interpreted as M - m.
        if (this.subspaceSize < 0) {
            this.subspaceSize = totalFeatures + this.subspaceSize;
        }
        // Clamp to at least one feature, then to at most the number available.
        if (this.subspaceSize <= 0) {
            this.subspaceSize = 1;
        }
        if (this.subspaceSize > totalFeatures) {
            this.subspaceSize = totalFeatures;
        }

        ARFHoeffdingTree treePrototype = (ARFHoeffdingTree) getPreparedClassOption(this.treeLearnerOption);
        treePrototype.resetLearning();
        // Identical for every member, so configure the prototype once before copying it.
        treePrototype.subspaceSizeOption.setValue(this.subspaceSize);

        boolean useBkgLearner = !this.disableBackgroundLearnerOption.isSet();
        boolean useDriftDetector = !this.disableDriftDetectionOption.isSet();
        for (int i = 0; i < ensembleSize; i++) {
            this.ensemble[i] = new ARFBaseLearner(i, (ARFHoeffdingTree) treePrototype.copy(), (BasicClassificationPerformanceEvaluator) evaluatorPrototype.copy(), this.instancesSeen, useBkgLearner, useDriftDetector, this.driftDetectionMethodOption, this.warningDetectionMethodOption, false);
        }
    }

    /**
     * Declares the views this learner is offered in. Only the exact
     * AdaptiveRandomForest class also declares the lite view; subclasses get
     * the standard view only.
     *
     * @return the capabilities for the runtime class of this object
     */
    @Override // moa.classifiers.AbstractClassifier, moa.capabilities.CapabilitiesHandler
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (getClass() != AdaptiveRandomForest.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
    }

    /**
     * Exposes the trees currently held by the base learners.
     *
     * @return one tree per ensemble member, in ensemble order; an empty array if the
     *         ensemble has not been created yet (it is built lazily on the first instance)
     */
    @Override // moa.classifiers.AbstractClassifier, moa.learners.Learner
    public Classifier[] getSublearners() {
        // The ensemble is null between a reset and the first instance; avoid a NullPointerException.
        if (this.ensemble == null) {
            return new Classifier[0];
        }
        Classifier[] trees = new Classifier[this.ensemble.length];
        for (int i = 0; i < trees.length; i++) {
            trees[i] = this.ensemble[i].classifier;
        }
        return trees;
    }
}
