package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/OzaBoost.class */
/**
 * Incremental on-line boosting of Oza and Russell.
 *
 * <p>Keeps a fixed-size array of copies of the configured base learner. For
 * every training instance the members are visited in order; each member is
 * trained with a weight derived from a running factor (lambda) that is
 * increased or decreased depending on whether the member classifies the
 * instance correctly. Two per-member accumulators record the lambda mass
 * seen on correctly ({@link #scms}) and wrongly ({@link #swms}) classified
 * instances, and are later used to weight each member's vote.
 */
public class OzaBoost extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;

    /** Base learner that is copied to create every ensemble member. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    /** Number of ensemble members to create. */
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE);

    /** When set, lambda is used directly as the training weight instead of a Poisson draw. */
    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', "Boost with weights only; no poisson.");

    /** The ensemble members, in boosting order. */
    protected Classifier[] ensemble;

    /** Per member: sum of lambda over instances the member classified correctly. */
    protected double[] scms;

    /** Per member: sum of lambda over instances the member classified wrongly. */
    protected double[] swms;

    /**
     * Returns a short human-readable description of this learner.
     *
     * @return the purpose string
     */
    @Override
    public String getPurposeString() {
        return "Incremental on-line boosting of Oza and Russell.";
    }

    /**
     * Rebuilds the ensemble from the configured base learner and zeroes the
     * per-member accumulators.
     */
    @Override
    public void resetLearningImpl() {
        this.ensemble = new Classifier[this.ensembleSizeOption.getValue()];
        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        for (int i = 0; i < this.ensemble.length; i++) {
            this.ensemble[i] = baseLearner.copy();
        }
        this.scms = new double[this.ensemble.length];
        this.swms = new double[this.ensemble.length];
    }

    /**
     * Trains every ensemble member, in order, on the given instance.
     *
     * <p>Lambda starts at 1 for the first member. After each member has been
     * trained, lambda is added to that member's correct or wrong accumulator
     * and then rescaled by {@code trainingWeightSeenByModel / (2 * accumulator)}
     * before moving on to the next member.
     *
     * @param instance the training instance; it is copied before its weight is changed
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        double lambda = 1.0d;
        for (int i = 0; i < this.ensemble.length; i++) {
            double k = this.pureBoostOption.isSet()
                    ? lambda
                    : MiscUtils.poisson(lambda, this.classifierRandom);
            if (k > 0.0d) {
                // Train on a re-weighted copy so the caller's instance is left untouched.
                Instance weighted = instance.copy();
                weighted.setWeight(instance.weight() * k);
                this.ensemble[i].trainOnInstance(weighted);
            }
            if (this.ensemble[i].correctlyClassifies(instance)) {
                this.scms[i] += lambda;
                lambda *= this.trainingWeightSeenByModel / (2.0d * this.scms[i]);
            } else {
                this.swms[i] += lambda;
                lambda *= this.trainingWeightSeenByModel / (2.0d * this.swms[i]);
            }
        }
    }

    /**
     * Computes the voting weight of one ensemble member from its accumulators.
     *
     * <p>With {@code e = swms[i] / (scms[i] + swms[i])}, the weight is
     * {@code log((1 - e) / e)}, written as {@code log(1 / (e / (1 - e)))}.
     * The result is 0 when {@code e} is 0 or greater than 0.5, and also when
     * the member has not accumulated any weight yet (which would otherwise
     * evaluate {@code 0 / 0} and propagate NaN into the votes).
     *
     * @param i index of the ensemble member
     * @return the member's voting weight, or 0 if the member must not vote
     */
    protected double getEnsembleMemberWeight(int i) {
        double total = this.scms[i] + this.swms[i];
        // Negated comparison on purpose: it also rejects a NaN total.
        if (!(total > 0.0d)) {
            return 0.0d;
        }
        double errorRate = this.swms[i] / total;
        if (errorRate == 0.0d || errorRate > 0.5d) {
            return 0.0d;
        }
        return Math.log(1.0d / (errorRate / (1.0d - errorRate)));
    }

    /**
     * Combines the members' votes into a single weighted vote vector.
     *
     * <p>Members are visited in order and the loop stops at the first member
     * whose weight is not strictly positive. A member's vote is skipped when
     * its values do not sum to more than 0; otherwise it is normalized,
     * scaled by the member's weight and added to the result.
     *
     * @param instance the instance to classify
     * @return the backing array of the combined vote vector
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.length; i++) {
            double memberWeight = getEnsembleMemberWeight(i);
            // "!(w > 0)" rather than "w <= 0" so that a NaN weight (for
            // example from an overriding subclass) also ends the loop.
            if (!(memberWeight > 0.0d)) {
                break;
            }
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance));
            if (vote.sumOfValues() > 0.0d) {
                vote.normalize();
                vote.scaleValues(memberWeight);
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    /**
     * Reports that this learner uses randomness; training draws Poisson
     * samples from {@code classifierRandom} unless pure boosting is selected.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Appends nothing; this learner currently provides no textual model description.
     *
     * @param sb the builder that would receive the description
     * @param i  the indentation level (unused)
     */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /**
     * Reports the ensemble size as the only model measurement.
     *
     * @return a one-element array holding the "ensemble size" measurement;
     *         the value is 0 while the ensemble has not been created
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        double size = this.ensemble != null ? this.ensemble.length : 0.0d;
        return new Measurement[] {new Measurement("ensemble size", size)};
    }

    /**
     * Returns the ensemble members.
     *
     * @return a shallow copy of the ensemble array, or an empty array while
     *         the ensemble has not been created
     */
    @Override
    public Classifier[] getSubClassifiers() {
        if (this.ensemble == null) {
            // Mirrors the null tolerance of getModelMeasurementsImpl instead of throwing.
            return new Classifier[0];
        }
        return this.ensemble.clone();
    }
}
