package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/BOLE.class */
/**
 * Boosting-like Online Learning Ensemble (BOLE).
 *
 * <p>The ensemble holds {@code ensembleSize} copies of the configured base learner.
 *
 * <p><b>Training.</b> For each training instance the members are first ranked by observed accuracy
 * {@code scms / (scms + swms)}, highest first. They are then visited one at a time:
 * <ul>
 *   <li>The first visit goes to the least accurate member.
 *   <li>After each visit, the next member is the most accurate unvisited one if the current
 *       member classified the instance correctly (after being trained on it), or the least
 *       accurate unvisited one otherwise.
 * </ul>
 * Each visited member is trained on a copy of the instance whose weight is multiplied by
 * {@code k}:
 * <ul>
 *   <li>{@code k} is the current boosting weight {@code lambda} when {@code pureBoost} is set.
 *   <li>Otherwise {@code k} is {@code MiscUtils.poisson(lambda, classifierRandom)}.
 * </ul>
 *
 * <p><b>Prediction.</b> A member votes only if both {@code scms} and {@code swms} are positive
 * for it and its error rate {@code e = swms / (scms + swms)} is at most {@code errorBound}. Its
 * normalized vote is scaled by {@code log((1 - e) / e) + weightShift}.
 */
public class BOLE extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "drift.SingleClassifierDrift -l trees.HoeffdingTree -d (DDM -n 7 -w 1.2 -o 1.95)");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE);
    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p', "Boost with weights only; no poisson.");
    public FlagOption breakVotesOption = new FlagOption("breakVotes", 'b', "Break Votes? unchecked=no, checked=yes");
    public FloatOption errorBoundOption = new FloatOption("errorBound", 'e', "Error bound percentage for allowing experts to vote.", 0.5d, 0.1d, 1.0d);
    public FloatOption weightShiftOption = new FloatOption("weightShift", 'w', "Weight shift associated with the error bound.", 0.0d, 0.0d, 5.0d);

    /** Ensemble members; (re)created by {@link #resetLearningImpl()}. */
    protected Classifier[] ensemble;

    /**
     * Member indices ranked by descending accuracy as of the most recent training call
     * (identity order right after a reset).
     */
    protected int[] orderPosition;

    /**
     * Per member: accumulated boosting weight {@code lambda} of training instances that the member
     * classified correctly after being trained on them.
     */
    protected double[] scms;

    /** Per member: accumulated boosting weight {@code lambda} of training instances it misclassified. */
    protected double[] swms;

    @Override
    public String getPurposeString() {
        return "Boosting-like Online Learning Ensemble (BOLE)";
    }

    /**
     * Builds a fresh ensemble of {@code ensembleSize} reset copies of the base learner, resets the
     * ranking to identity order and clears the per-member correct/wrong weight tallies.
     */
    @Override
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        Classifier prototype = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        prototype.resetLearning();
        this.ensemble = new Classifier[size];
        this.orderPosition = new int[size];
        for (int member = 0; member < size; member++) {
            this.ensemble[member] = prototype.copy();
            this.orderPosition[member] = member;
        }
        this.scms = new double[size];
        this.swms = new double[size];
    }

    /**
     * Trains every ensemble member once on {@code instance}, in the accuracy-driven visiting order
     * described on the class, updating the boosting weight after each member.
     *
     * @param instance the training instance; it is never modified, weighted copies are passed on
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        sortMembersByAccuracy();

        // Two cursors into the ranking: nextBest walks down from the most accurate member,
        // nextWorst walks up from the least accurate one. Together they visit each member once.
        int nextBest = 0;
        int nextWorst = this.ensemble.length - 1;
        boolean previousCorrect = false; // false => the first visit goes to the least accurate member
        double lambda = 1.0;
        for (int visited = 0; visited < this.ensemble.length; visited++) {
            int member = previousCorrect
                    ? this.orderPosition[nextBest++]
                    : this.orderPosition[nextWorst--];

            // The Poisson draw only happens (and only consumes randomness) when pureBoost is off.
            double k = this.pureBoostOption.isSet()
                    ? lambda
                    : MiscUtils.poisson(lambda, this.classifierRandom);
            if (k > 0.0) {
                Instance weighted = instance.copy();
                weighted.setWeight(instance.weight() * k);
                this.ensemble[member].trainOnInstance(weighted);
            }

            previousCorrect = this.ensemble[member].correctlyClassifies(instance);
            double[] outcomeWeights = previousCorrect ? this.scms : this.swms;
            outcomeWeights[member] += lambda;
            // Rescale lambda by (total training weight seen) / (2 * this member's tally for the
            // observed outcome); the tally is > 0 here because lambda was just added to it.
            lambda *= this.trainingWeightSeenByModel / (2.0 * outcomeWeights[member]);
        }
    }

    /**
     * Re-orders {@link #orderPosition} by descending accuracy {@code scms / (scms + swms)}.
     * Members with no recorded weight count as accuracy 0. Uses an insertion sort seeded with the
     * previous ranking; the strict {@code <} keeps equally accurate members in their previous
     * relative order.
     */
    private void sortMembersByAccuracy() {
        int size = this.orderPosition.length;
        double[] accuracy = new double[size];
        for (int rank = 0; rank < size; rank++) {
            int member = this.orderPosition[rank];
            double total = this.scms[member] + this.swms[member];
            accuracy[rank] = total != 0.0 ? this.scms[member] / total : 0.0;
        }
        for (int rank = 1; rank < size; rank++) {
            int keyMember = this.orderPosition[rank];
            double keyAccuracy = accuracy[rank];
            int slot = rank - 1;
            while (slot >= 0 && accuracy[slot] < keyAccuracy) {
                this.orderPosition[slot + 1] = this.orderPosition[slot];
                accuracy[slot + 1] = accuracy[slot];
                slot--;
            }
            this.orderPosition[slot + 1] = keyMember;
            accuracy[slot + 1] = keyAccuracy;
        }
    }

    /**
     * Tells whether member {@code member} may vote. This requires both of the following:
     * <ul>
     *   <li>it has positive correct and wrong tallies, so its error rate lies strictly between
     *       0 and 1;
     *   <li>that error rate does not exceed {@code errorBound}.
     * </ul>
     * A member with no recorded errors is therefore not eligible.
     */
    private boolean isVotingEligible(int member) {
        double correctWeight = this.scms[member];
        double wrongWeight = this.swms[member];
        return correctWeight > 0.0
                && wrongWeight > 0.0
                && wrongWeight / (correctWeight + wrongWeight) <= this.errorBoundOption.getValue();
    }

    /**
     * Returns the voting weight of member {@code i} before the {@code weightShift} is added. For
     * an eligible member the weight is {@code log((1 - e) / e)}, where
     * {@code e = swms[i] / (scms[i] + swms[i])}. Every other member gets 0.
     *
     * <p>Note that the result is negative when {@code e > 0.5}, which an {@code errorBound} above
     * 0.5 allows.
     *
     * @param i index into {@link #ensemble}
     * @return the member's unshifted voting weight, or 0 if it is not eligible to vote
     */
    protected double getEnsembleMemberWeight(int i) {
        if (!isVotingEligible(i)) {
            return 0.0d;
        }
        double errorRate = this.swms[i] / (this.scms[i] + this.swms[i]);
        double beta = errorRate / (1.0d - errorRate);
        return Math.log(1.0d / beta);
    }

    /**
     * Combines the normalized votes of all eligible members. Each vote is scaled by
     * {@link #getEnsembleMemberWeight(int)} plus {@code weightShift}. Members whose own votes sum
     * to zero are skipped. With {@code breakVotes} set, voting stops at the first ineligible member.
     *
     * <p>NOTE(review): members are scanned in array index order, not in the accuracy order used for
     * training, so {@code breakVotes} stops at the first ineligible member by index. Confirm this
     * is intended.
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        DoubleVector combinedVote = new DoubleVector();
        double weightShift = this.weightShiftOption.getValue();
        for (int member = 0; member < this.ensemble.length; member++) {
            if (!isVotingEligible(member)) {
                if (this.breakVotesOption.isSet()) {
                    break;
                }
                continue;
            }
            double memberWeight = getEnsembleMemberWeight(member) + weightShift;
            DoubleVector vote = new DoubleVector(this.ensemble[member].getVotesForInstance(instance));
            if (vote.sumOfValues() > 0.0d) {
                vote.normalize();
                vote.scaleValues(memberWeight);
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder sb, int indent) {
    }

    /** Reports the number of ensemble members, or 0 before the first reset. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[] {
            new Measurement("ensemble size", this.ensemble != null ? this.ensemble.length : 0.0d)
        };
    }

    /**
     * Returns a shallow copy of the member array, or an empty array when the ensemble has not been
     * built yet.
     */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble == null ? new Classifier[0] : this.ensemble.clone();
    }
}
