package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/OCBoost.class */
/**
 * Online Coordinate Boosting (OCBoost) for two-class evolving data streams.
 *
 * <p>The ensemble is a fixed-size chain of copies of a base learner. For each training instance
 * the members are visited in index order. Each member's voting weight {@code alpha[j]} is
 * recomputed from the pairwise statistics {@code wpos}/{@code wneg}. The member is then trained
 * on a copy of the instance whose weight is scaled by a running boosting factor.
 *
 * <p>Prediction is an {@code alpha}-weighted vote over the members' predicted class indices.
 * The output array always has length 2.
 */
public class OCBoost extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1L;

    /** Base learner; one copy of it is made per ensemble member. */
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    /** Number of ensemble members. */
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models to boost.", 10, 1, Integer.MAX_VALUE);

    /** Initial value of every entry of {@link #wpos} and {@link #wneg}. */
    public FloatOption smoothingOption = new FloatOption("smoothingParameter", 'e', "Smoothing parameter.", 0.5d, 0.0d, 100.0d);

    /** Ensemble members, trained and queried in index order. */
    protected Classifier[] ensemble;

    /** Current voting weight of each member: {@code 0.5 * ln(wpos[j][j] / wneg[j][j])}. */
    protected double[] alpha;

    /** Change of each member's {@code alpha} at its most recent update (new minus previous). */
    protected double[] alphainc;

    /** Per-instance factor by which row j of {@link #wpos} is rescaled before it is updated. */
    protected double[] pipos;

    /** Per-instance factor by which row j of {@link #wneg} is rescaled before it is updated. */
    protected double[] pineg;

    /**
     * {@code wpos[j][k]} (k &lt;= j) accumulates the boosting weight of instances that members k
     * and j both classified correctly, rescaled by {@code pipos[j]} at every update.
     */
    protected double[][] wpos;

    /**
     * {@code wneg[j][k]} (k &lt;= j) accumulates the boosting weight of instances that members k
     * and j both misclassified, rescaled by {@code pineg[j]} at every update.
     */
    protected double[][] wneg;

    @Override
    public String getPurposeString() {
        return "Online Coordinate boosting for two classes evolving data streams.";
    }

    /**
     * Allocates the ensemble and all per-member state.
     *
     * <p>Every member is a fresh copy of the reset base learner. Every {@code alpha} and
     * {@code alphainc} starts at 0. Every entry of {@code wpos}/{@code wneg} starts at the
     * smoothing parameter.
     */
    @Override
    public void resetLearningImpl() {
        int size = this.ensembleSizeOption.getValue();
        // Java zero-initialises new arrays, so alpha/alphainc need no explicit reset.
        this.ensemble = new Classifier[size];
        this.alpha = new double[size];
        this.alphainc = new double[size];
        this.pipos = new double[size];
        this.pineg = new double[size];
        this.wpos = new double[size][size];
        this.wneg = new double[size][size];

        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        double smoothing = this.smoothingOption.getValue();
        for (int j = 0; j < size; j++) {
            this.ensemble[j] = baseLearner.copy();
            for (int k = 0; k < size; k++) {
                this.wpos[j][k] = smoothing;
                this.wneg[j][k] = smoothing;
            }
        }
    }

    /**
     * Trains every ensemble member on {@code inst}, in index order.
     *
     * <p>For member j, the steps are:
     * <ol>
     *   <li>Record whether it currently classifies {@code inst} correctly ({@code m[j] = +1})
     *       or not ({@code m[j] = -1}).</li>
     *   <li>Rescale row j of {@code wpos}/{@code wneg} by the effect of the preceding members'
     *       latest {@code alphainc}, then add this instance's contribution.</li>
     *   <li>Recompute {@code alpha[j]}.</li>
     *   <li>Multiply the running weight {@code d} by {@code exp(-alpha[j] * m[j])}.</li>
     *   <li>If {@code d > 0}, train member j on a copy of {@code inst} with weight
     *       {@code inst.weight() * d}.</li>
     * </ol>
     *
     * <p>NOTE(review): the smoothing option accepts 0.0. In that case {@code wpos[j][j]} or
     * {@code wneg[j][j]} can be 0, and the ratios and the {@code Math.log} below may yield
     * NaN/Infinity.
     *
     * @param inst the training instance; it is not modified, only copied
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        int size = this.ensemble.length;
        // Running boosting weight: shrinks after members that were correct, grows after
        // members that were wrong (for positive alpha).
        double d = 1.0;
        // m[j] = +1 if member j classified inst correctly before training on it, else -1.
        int[] m = new int[size];

        for (int j = 0; j < size; j++) {
            this.pipos[j] = 1.0;
            this.pineg[j] = 1.0;
            m[j] = this.ensemble[j].correctlyClassifies(inst) ? 1 : -1;

            // Rescale row j by how the alpha increments of members 0..j-1 change the weights.
            // wpos[j][j] and wneg[j][j] are not modified inside this loop.
            for (int k = 0; k < j; k++) {
                double expNeg = Math.exp(-this.alphainc[k]);
                double expPos = Math.exp(this.alphainc[k]);
                double ratioPos = this.wpos[j][k] / this.wpos[j][j];
                double ratioNeg = this.wneg[j][k] / this.wneg[j][j];
                this.pipos[j] *= ratioPos * expNeg + (1.0 - ratioPos) * expPos;
                this.pineg[j] *= ratioNeg * expNeg + (1.0 - ratioNeg) * expPos;
            }

            // Fold this instance into row j. The d-weighted term is added only when members
            // k and j agree: both correct for wpos, both wrong for wneg.
            for (int k = 0; k <= j; k++) {
                this.wpos[j][k] = this.wpos[j][k] * this.pipos[j]
                        + d * indicator(m[k] == 1) * indicator(m[j] == 1);
                this.wneg[j][k] = this.wneg[j][k] * this.pineg[j]
                        + d * indicator(m[k] == -1) * indicator(m[j] == -1);
            }

            double previousAlpha = this.alpha[j];
            this.alpha[j] = 0.5 * Math.log(this.wpos[j][j] / this.wneg[j][j]);
            this.alphainc[j] = this.alpha[j] - previousAlpha;

            // d is updated before training, so member j already sees its own factor.
            d *= Math.exp(-this.alpha[j] * m[j]);
            if (d > 0.0) {
                Instance weighted = inst.copy();
                weighted.setWeight(inst.weight() * d);
                this.ensemble[j].trainOnInstance(weighted);
            }
        }
    }

    /** Returns 1 when {@code condition} holds, otherwise 0. */
    private static int indicator(boolean condition) {
        return condition ? 1 : 0;
    }

    /**
     * Returns the voting weight of ensemble member {@code i}.
     *
     * @param i member index
     * @return {@code alpha[i]}
     */
    protected double getEnsembleMemberWeight(int i) {
        return this.alpha[i];
    }

    /**
     * Predicts by a weighted vote of the ensemble members.
     *
     * <p>Each member's predicted class index becomes its vote, with class 0 mapped to -1. The
     * vote is multiplied by {@link #getEnsembleMemberWeight(int)} and summed. The result puts
     * 1.0 at index 1 if the sum is positive, otherwise at index 0 (ties go to class 0).
     *
     * <p>NOTE(review): a member predicting class index 2 or higher contributes that index times
     * its weight, so the vote is only a clean +/-1 on two-class streams.
     *
     * @param inst the instance to classify
     * @return a length-2 array with a single 1.0 entry
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        double combinedVote = 0.0;
        for (int i = 0; i < this.ensemble.length; i++) {
            int predicted = Utils.maxIndex(this.ensemble[i].getVotesForInstance(inst));
            int vote = predicted == 0 ? -1 : predicted;
            combinedVote += vote * getEnsembleMemberWeight(i);
        }
        double[] votes = new double[2];
        votes[combinedVote > 0.0 ? 1 : 0] = 1.0;
        return votes;
    }

    /**
     * Reports this learner as randomizable.
     *
     * <p>NOTE(review): no code in this class draws random numbers; confirm that returning
     * {@code true} is intended.
     */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /** Reports the ensemble size, or 0 if the ensemble has not been allocated yet. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[] {
            new Measurement("ensemble size", this.ensemble != null ? this.ensemble.length : 0.0d)
        };
    }

    /**
     * Returns a shallow copy of the ensemble array.
     *
     * @return the ensemble members, or an empty array if {@link #resetLearningImpl()} has not
     *     run yet (previously this threw a NullPointerException)
     */
    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble == null ? new Classifier[0] : this.ensemble.clone();
    }
}
