package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.AbstractMOAObject;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Regressor;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.classifiers.trees.ARFFIMTDD;
import moa.core.DoubleVector;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.evaluation.BasicRegressionPerformanceEvaluator;
import moa.options.ClassOption;
import org.apache.log4j.Priority;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/AdaptiveRandomForestRegressor.class */
/**
 * Adaptive Random Forest for regression.
 *
 * <p>An ensemble of {@link ARFFIMTDD} trees. Each tree is trained with online bagging: every
 * incoming instance is given a Poisson(lambda) weight per tree, and trees whose draw is zero skip
 * it. Each tree's {@code subspaceSizeOption} is set to a subspace size derived from
 * {@code mFeaturesMode} / {@code mFeaturesPerTreeSize}. Every member may own a drift detector,
 * which resets the member, and a warning detector, which starts a background tree. The background
 * tree replaces the foreground tree when drift is confirmed. The ensemble prediction is the
 * unweighted mean of the members' predictions.
 */
public class AdaptiveRandomForestRegressor extends AbstractClassifier implements Regressor {
    private static final long serialVersionUID = 1;

    public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l',
            "Random Forest Tree.", ARFFIMTDD.class,
            "ARFFIMTDD -s VarianceReductionSplitCriterion -g 50 -c 0.01");

    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
            "The number of trees.", 100, 1, Integer.MAX_VALUE);

    public MultiChoiceOption mFeaturesModeOption = new MultiChoiceOption("mFeaturesMode", 'o',
            "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.",
            new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", "Percentage (M * (m / 100))"},
            new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 3);

    // The minimum was decompiled as log4j's Priority.ALL_INT, which is just Integer.MIN_VALUE.
    public IntOption mFeaturesPerTreeSizeOption = new IntOption("mFeaturesPerTreeSize", 'm',
            "Number of features allowed considered for each split. Negative values corresponds to M - m",
            60, Integer.MIN_VALUE, Integer.MAX_VALUE);

    // The maximum was decompiled as the literal 3.4028234663852886E38, i.e. Float.MAX_VALUE.
    public FloatOption lambdaOption = new FloatOption("lambda", 'a',
            "The lambda parameter for bagging.", 6.0d, 1.0d, Float.MAX_VALUE);

    public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x',
            "Change detector for drifts and its parameters", ChangeDetector.class,
            "ADWINChangeDetector -a 1.0E-3");

    public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p',
            "Change detector for warnings (start training bkg learner)", ChangeDetector.class,
            "ADWINChangeDetector -a 1.0E-2");

    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'u',
            "Should use drift detection? If disabled then bkg learner is also disabled");

    public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q',
            "Should use bkg learner? If disabled then reset tree immediately.");

    /** Indices of {@link #mFeaturesModeOption}'s choices. */
    protected static final int FEATURES_M = 0;
    protected static final int FEATURES_SQRT = 1;
    protected static final int FEATURES_SQRT_INV = 2;
    protected static final int FEATURES_PERCENT = 3;

    /** Ensemble members; {@code null} until the first train/predict call builds them. */
    protected ARFFIMTDDBaseLearner[] ensemble;
    /** Number of training instances received since the last reset. */
    protected long instancesSeen;
    /** Number of features each tree is configured to consider. */
    protected int subspaceSize;
    /** Created on reset; not read elsewhere in this class (kept for subclasses/compatibility). */
    protected BasicRegressionPerformanceEvaluator evaluator;

    /**
     * One ensemble member: a foreground tree together with its own evaluator, optional
     * drift/warning detectors and an optional background tree.
     *
     * <p>This is an inner (non-static) class because it reads the forest's {@code instancesSeen}
     * and uses the forest's {@code getPreparedClassOption} to instantiate detectors.
     */
    public final class ARFFIMTDDBaseLearner extends AbstractMOAObject {
        /** Position of this member in the ensemble. */
        public int indexOriginal;
        /** Forest instance counter at which the current foreground tree started learning. */
        public long createdOn;
        /** Instance counter at the last detected drift (0 if none yet). */
        public long lastDriftOn;
        /** Instance counter at the last detected warning (0 if none yet). */
        public long lastWarningOn;
        /** The foreground tree used for prediction. */
        public ARFFIMTDD classifier;
        /** True for a background tree; such trees never run their own detectors. */
        public boolean isBackgroundLearner;
        protected ClassOption driftOption;
        protected ClassOption warningOption;
        /** Non-null only when {@link #useDriftDetector} was true at construction. */
        protected ChangeDetector driftDetectionMethod;
        /** Non-null only when {@link #useBkgLearner} was true at construction. */
        protected ChangeDetector warningDetectionMethod;
        public boolean useBkgLearner;
        public boolean useDriftDetector;
        /** Tree being trained in the background since the last warning, or {@code null}. */
        protected ARFFIMTDDBaseLearner bkgLearner;
        /** Prequential evaluator for this member's predictions. */
        public BasicRegressionPerformanceEvaluator evaluator;
        protected int numberOfDriftsDetected;
        protected int numberOfWarningsDetected;

        /**
         * Creates an ensemble member.
         *
         * @param indexOriginal position of the member in the ensemble
         * @param classifier the tree to wrap
         * @param evaluator evaluator used to track this member's predictions
         * @param createdOn forest instance counter at creation time
         * @param useBkgLearner whether a warning detector / background tree is used
         * @param useDriftDetector whether a drift detector is used
         * @param driftOption option describing the drift detector to instantiate
         * @param warningOption option describing the warning detector to instantiate
         * @param isBackgroundLearner whether this member is itself a background tree
         */
        public ARFFIMTDDBaseLearner(int indexOriginal, ARFFIMTDD classifier,
                BasicRegressionPerformanceEvaluator evaluator, long createdOn,
                boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption,
                ClassOption warningOption, boolean isBackgroundLearner) {
            this.indexOriginal = indexOriginal;
            this.createdOn = createdOn;
            this.lastDriftOn = 0L;
            this.lastWarningOn = 0L;
            this.classifier = classifier;
            this.evaluator = evaluator;
            this.useBkgLearner = useBkgLearner;
            this.useDriftDetector = useDriftDetector;
            this.numberOfDriftsDetected = 0;
            this.numberOfWarningsDetected = 0;
            this.isBackgroundLearner = isBackgroundLearner;
            if (this.useDriftDetector) {
                this.driftOption = driftOption;
                this.driftDetectionMethod = newDetector(this.driftOption);
            }
            if (this.useBkgLearner) {
                this.warningOption = warningOption;
                this.warningDetectionMethod = newDetector(this.warningOption);
            }
        }

        /** Instantiates a fresh copy of the change detector described by {@code option}. */
        private ChangeDetector newDetector(ClassOption option) {
            return ((ChangeDetector) AdaptiveRandomForestRegressor.this.getPreparedClassOption(option)).copy();
        }

        /**
         * Handles a confirmed drift.
         *
         * <p>If a background tree exists (and background learning is enabled), that tree takes
         * over together with its detectors, evaluator and creation time. Otherwise the foreground
         * tree is reset in place and given a new drift detector. In both cases the resulting
         * evaluator is reset.
         */
        public void reset() {
            if (this.useBkgLearner && this.bkgLearner != null) {
                this.classifier = this.bkgLearner.classifier;
                this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod;
                this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod;
                this.evaluator = this.bkgLearner.evaluator;
                this.createdOn = this.bkgLearner.createdOn;
                this.bkgLearner = null;
            } else {
                this.classifier.resetLearning();
                this.createdOn = AdaptiveRandomForestRegressor.this.instancesSeen;
                this.driftDetectionMethod = newDetector(this.driftOption);
            }
            this.evaluator.reset();
        }

        /**
         * Trains this member on one instance and then updates its detectors.
         *
         * @param instance the training instance (not modified; a weighted copy is trained on)
         * @param weight bagging multiplier applied to the instance's own weight
         * @param timestamp forest instance counter, recorded on warnings and drifts
         */
        public void trainOnInstance(Instance instance, double weight, long timestamp) {
            Instance weightedInstance = instance.copy();
            weightedInstance.setWeight(instance.weight() * weight);
            this.classifier.trainOnInstance(weightedInstance);

            // NOTE(review): the background tree receives the unweighted instance, so it does not
            // see the bagging weight the foreground tree got; confirm this is intended.
            if (this.bkgLearner != null) {
                this.bkgLearner.classifier.trainOnInstance(instance);
            }

            // Background trees never run detectors, so they never spawn background trees.
            if (!this.useDriftDetector || this.isBackgroundLearner) {
                return;
            }

            // NOTE(review): the detectors are fed the tree's raw prediction (after training),
            // not an error signal such as |prediction - target|; confirm this is intended.
            double prediction = this.classifier.getVotesForInstance(instance)[0];

            if (this.useBkgLearner) {
                this.warningDetectionMethod.input(prediction);
                if (this.warningDetectionMethod.getChange()) {
                    startBackgroundLearner(timestamp);
                }
            }

            this.driftDetectionMethod.input(prediction);
            if (this.driftDetectionMethod.getChange()) {
                this.lastDriftOn = timestamp;
                this.numberOfDriftsDetected++;
                reset();
            }
        }

        /**
         * Reacts to a warning: records it and replaces any existing background tree with a fresh,
         * untrained copy of the foreground tree. The warning detector is then re-created.
         */
        private void startBackgroundLearner(long timestamp) {
            this.lastWarningOn = timestamp;
            this.numberOfWarningsDetected++;

            ARFFIMTDD freshTree = (ARFFIMTDD) this.classifier.copy();
            freshTree.resetLearning();
            BasicRegressionPerformanceEvaluator freshEvaluator =
                    (BasicRegressionPerformanceEvaluator) this.evaluator.copy();
            freshEvaluator.reset();

            this.bkgLearner = new ARFFIMTDDBaseLearner(this.indexOriginal, freshTree, freshEvaluator,
                    timestamp, this.useBkgLearner, this.useDriftDetector, this.driftOption,
                    this.warningOption, true);
            this.warningDetectionMethod = newDetector(this.warningOption);
        }

        /** Returns the foreground tree's prediction for {@code instance}. */
        public double[] getVotesForInstance(Instance instance) {
            return this.classifier.getVotesForInstance(instance);
        }

        @Override
        public void getDescription(StringBuilder sb, int indent) {
        }
    }

    @Override
    public String getPurposeString() {
        return "Adaptive Random Forest Regressor algorithm for evolving data streams from Gomes et al.";
    }

    /** Discards the ensemble; it is rebuilt lazily from the next instance seen. */
    @Override
    public void resetLearningImpl() {
        this.ensemble = null;
        this.subspaceSize = 0;
        this.instancesSeen = 0L;
        this.evaluator = new BasicRegressionPerformanceEvaluator();
    }

    /**
     * Processes one training instance. Each member is first evaluated on it (test-then-train),
     * then trained with a Poisson(lambda) bagging weight, skipping members that draw 0.
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        this.instancesSeen++;
        if (this.ensemble == null) {
            initEnsemble(instance);
        }
        // All members see the same instance, so one example wrapper suffices.
        InstanceExample example = new InstanceExample(instance);
        double lambda = this.lambdaOption.getValue();
        for (ARFFIMTDDBaseLearner member : this.ensemble) {
            member.evaluator.addResult(example,
                    new DoubleVector(member.getVotesForInstance(instance)).getArrayRef());
            int k = MiscUtils.poisson(lambda, this.classifierRandom);
            if (k > 0) {
                member.trainOnInstance(instance, k, this.instancesSeen);
            }
        }
    }

    /**
     * Predicts the target as the unweighted mean of the members' predictions. Builds the ensemble
     * first if no instance has been seen yet.
     *
     * @return a one-element array holding the mean prediction
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        if (this.ensemble == null) {
            initEnsemble(testInstance);
        }
        double sum = 0.0d;
        for (ARFFIMTDDBaseLearner member : this.ensemble) {
            sum += member.getVotesForInstance(testInstance)[0];
        }
        return new double[]{sum / this.ensemble.length};
    }

    /** No model-specific measurements are reported (returns {@code null}). */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Builds {@code ensembleSize} members from a prototype tree configured with the subspace size
     * derived from {@code instance}'s attribute count (excluding the class attribute).
     */
    protected void initEnsemble(Instance instance) {
        int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new ARFFIMTDDBaseLearner[ensembleSize];
        BasicRegressionPerformanceEvaluator evaluatorTemplate = new BasicRegressionPerformanceEvaluator();
        this.subspaceSize = computeSubspaceSize(instance.numAttributes() - 1);

        ARFFIMTDD treeTemplate = (ARFFIMTDD) getPreparedClassOption(this.treeLearnerOption);
        treeTemplate.resetLearning();
        // Loop-invariant: every member copies a prototype with the same subspace size.
        treeTemplate.subspaceSizeOption.setValue(this.subspaceSize);

        boolean useBkgLearner = !this.disableBackgroundLearnerOption.isSet();
        boolean useDriftDetector = !this.disableDriftDetectionOption.isSet();
        for (int i = 0; i < ensembleSize; i++) {
            this.ensemble[i] = new ARFFIMTDDBaseLearner(i, (ARFFIMTDD) treeTemplate.copy(),
                    (BasicRegressionPerformanceEvaluator) evaluatorTemplate.copy(), this.instancesSeen,
                    useBkgLearner, useDriftDetector, this.driftDetectionMethodOption,
                    this.warningDetectionMethodOption, false);
        }
    }

    /**
     * Resolves the per-tree subspace size from {@code mFeaturesMode} and {@code m}, clamped to
     * {@code [1, numFeatures]} (or 0 when there are no features). A negative result is
     * interpreted as "M minus that many" before clamping.
     */
    private int computeSubspaceSize(int numFeatures) {
        int m = this.mFeaturesPerTreeSizeOption.getValue();
        int size;
        switch (this.mFeaturesModeOption.getChosenIndex()) {
            case FEATURES_SQRT:
                size = (int) Math.round(Math.sqrt(numFeatures)) + 1;
                break;
            case FEATURES_SQRT_INV:
                size = numFeatures - (int) Math.round(Math.sqrt(numFeatures) + 1.0d);
                break;
            case FEATURES_PERCENT: {
                // A negative percentage means "100 + m percent".
                double fraction = m < 0 ? (100 + m) / 100.0d : m / 100.0d;
                size = (int) Math.round(numFeatures * fraction);
                break;
            }
            case FEATURES_M:
            default:
                size = m;
                break;
        }
        if (size < 0) {
            size = numFeatures + size;
        }
        if (size <= 0) {
            size = 1;
        }
        if (size > numFeatures) {
            size = numFeatures;
        }
        return size;
    }

    @Override
    public void getModelDescription(StringBuilder sb, int indent) {
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }
}
