package moa.classifiers.meta.imbalanced;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/imbalanced/OnlineRUSBoost.class */
/**
 * Online RUSBoost: an online boosting ensemble that rebalances classes for imbalanced streams.
 *
 * <p>Each training instance is presented to every ensemble member in turn. A member is trained
 * {@code k ~ Poisson(lambda')} times, where {@code lambda'} is the current boosting weight
 * {@code lambda} rescaled by the chosen {@link #algorithmImplementationOption}. The rescaling
 * rebalances instances whose class value is {@code 1} (treated as the positive class) against all
 * other instances (treated as negative). After training, {@code lambda} is updated from the
 * member's running weighted error, in the style of Oza's online boosting, and then handed on to
 * the next member.
 *
 * <p>Unless {@link #disableDriftDetectionOption} is set, each member's error rate is monitored by
 * an {@link ADWIN} detector. When a detector reports that the error has increased, the member with
 * the highest estimated error is replaced by a fresh copy of the base learner.
 */
public class OnlineRUSBoost extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
    private static final long serialVersionUID = 1L;
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "meta.AdaptiveRandomForest");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The size of the ensemble.", 10, 1, Integer.MAX_VALUE);
    public IntOption samplingRateOption = new IntOption("samplingRate", 'i', "The sampling rate of the positive instances.", 3, 1, 10);
    public MultiChoiceOption algorithmImplementationOption = new MultiChoiceOption("algorithmImplementation", 'a', "The implementation of RUSBoost to use.", new String[]{"Fixed class ration", "Fixed example distribution", "Fixed sampling rate"}, new String[]{"ClassRation", "ExampleDistribution", "SamplingRate"}, 0);
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'd', "Disable the ADWIN drift detector.");

    /** Prepared prototype that every ensemble member is copied from. */
    protected Classifier baseLearner;
    /** Current number of ensemble members (always equal to {@code ensemble.size()}). */
    protected int nEstimators;
    protected int samplingRate;
    /** Index chosen in {@link #algorithmImplementationOption}. */
    protected int algorithmImplementation;
    protected boolean driftDetection;
    protected ArrayList<Classifier> ensemble;
    /** One ADWIN per member, tracking its error rate; null when drift detection is disabled. */
    protected ArrayList<ADWIN> adwinEnsemble;
    /** Per-member sum of boosting weights of correctly classified instances. */
    protected ArrayList<Double> lambdaSc;
    /** Per-member sum of boosting weights of positive (class 1) instances. */
    protected ArrayList<Double> lambdaPos;
    /** Per-member sum of boosting weights of negative (non class 1) instances. */
    protected ArrayList<Double> lambdaNeg;
    /** Per-member sum of boosting weights of misclassified instances. */
    protected ArrayList<Double> lambdaSw;
    /** Per-member weighted error rate {@code lambdaSw / (lambdaSc + lambdaSw)}. */
    protected ArrayList<Double> epsilon;
    /** Number of positive instances seen so far. */
    protected double nPositive;
    /** Number of negative instances seen so far. */
    protected double nNegative;

    @Override
    public String getPurposeString() {
        return "Online RUSBoost is the adaptation of the ensemble learner to data streams.";
    }

    /**
     * Rebuilds the ensemble from the current options and clears all running statistics.
     */
    @Override
    public void resetLearningImpl() {
        this.baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        this.baseLearner.resetLearning();
        this.nEstimators = this.ensembleSizeOption.getValue();
        this.samplingRate = this.samplingRateOption.getValue();
        this.algorithmImplementation = this.algorithmImplementationOption.getChosenIndex();
        this.driftDetection = !this.disableDriftDetectionOption.isSet();
        this.ensemble = new ArrayList<>();
        this.adwinEnsemble = this.driftDetection ? new ArrayList<>() : null;
        this.lambdaSc = new ArrayList<>();
        this.lambdaPos = new ArrayList<>();
        this.lambdaNeg = new ArrayList<>();
        this.lambdaSw = new ArrayList<>();
        this.epsilon = new ArrayList<>();
        for (int i = 0; i < this.nEstimators; i++) {
            addMember();
        }
        this.nPositive = 0.0d;
        this.nNegative = 0.0d;
        this.classifierRandom = new Random(this.randomSeed);
    }

    /**
     * Trains every ensemble member on {@code instance} using class-rebalanced online boosting.
     *
     * @param instance the labelled training instance; class value 1 is treated as positive
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.ensemble == null || this.ensemble.isEmpty()) {
            resetLearningImpl();
        }
        adjustEnsembleSize(instance.numClasses());

        boolean isPositive = instance.classValue() == 1.0d;
        // Count each instance once. Counting once per member would inflate and skew the ratios.
        if (isPositive) {
            this.nPositive += 1.0d;
        } else {
            this.nNegative += 1.0d;
        }

        double lambda = 1.0d;
        boolean driftDetected = false;
        for (int i = 0; i < this.ensemble.size(); i++) {
            if (isPositive) {
                this.lambdaPos.set(i, this.lambdaPos.get(i) + lambda);
            } else {
                this.lambdaNeg.set(i, this.lambdaNeg.get(i) + lambda);
            }

            Classifier member = this.ensemble.get(i);
            int repetitions = MiscUtils.poisson(computeSamplingLambda(i, lambda, isPositive), this.classifierRandom);
            for (int r = 0; r < repetitions; r++) {
                member.trainOnInstance(instance);
            }

            // Predict once after training and reuse the result for boosting and drift monitoring.
            boolean correct = Utils.maxIndex(member.getVotesForInstance(instance)) == instance.classValue();
            lambda = updateBoostingWeight(i, lambda, correct);

            if (this.driftDetection && monitorDrift(i, correct)) {
                driftDetected = true;
            }
        }
        if (driftDetected) {
            resetWorstMember();
        }
    }

    /**
     * Computes the Poisson rate used to train member {@code i} on the current instance.
     *
     * <p>The rate depends on the selected implementation:
     * <ul>
     *   <li>0: fixed class ratio</li>
     *   <li>1: fixed example distribution</li>
     *   <li>2: fixed sampling rate</li>
     * </ul>
     * Any other index, or a class ratio that is still undefined, yields a rate of 1.
     */
    private double computeSamplingLambda(int i, double lambda, boolean isPositive) {
        double pos = this.lambdaPos.get(i);
        double neg = this.lambdaNeg.get(i);
        double seen = this.nPositive + this.nNegative;
        if (this.algorithmImplementation == 0) {
            double positiveShare = ((this.samplingRate + 1) * this.nPositive) / seen;
            if (isPositive) {
                if (this.nNegative == 0.0d) {
                    return 1.0d;
                }
                return lambda * ((pos + neg) / (pos + (neg * (this.samplingRate * (this.nPositive / this.nNegative))))) * positiveShare;
            }
            if (this.nPositive == 0.0d) {
                return 1.0d;
            }
            return lambda * ((pos + neg) / (pos + (neg * (this.nNegative / (this.nPositive * this.samplingRate))))) * positiveShare;
        }
        if (this.algorithmImplementation == 1) {
            if (isPositive) {
                return ((lambda * this.nPositive) / seen) / (pos / (pos + neg));
            }
            return (((lambda * this.samplingRate) * this.nPositive) / seen) / (neg / (pos + neg));
        }
        if (this.algorithmImplementation == 2) {
            return isPositive ? lambda : lambda / this.samplingRate;
        }
        return 1.0d;
    }

    /**
     * Records member {@code i}'s outcome, refreshes its error rate, and returns the boosting
     * weight to hand on to the next member.
     */
    private double updateBoostingWeight(int i, double lambda, boolean correct) {
        if (correct) {
            this.lambdaSc.set(i, this.lambdaSc.get(i) + lambda);
        } else {
            this.lambdaSw.set(i, this.lambdaSw.get(i) + lambda);
        }
        double sc = this.lambdaSc.get(i);
        double sw = this.lambdaSw.get(i);
        double error = sw / (sc + sw);
        this.epsilon.set(i, error);
        if (correct) {
            return error != 1.0d ? lambda / (2.0d * (1.0d - error)) : lambda;
        }
        return error != 0.0d ? lambda / (2.0d * error) : lambda;
    }

    /**
     * Feeds member {@code i}'s error indicator to its ADWIN detector.
     *
     * @return true when ADWIN detected a change and the estimated error rate went up
     */
    private boolean monitorDrift(int i, boolean correct) {
        ADWIN detector = this.adwinEnsemble.get(i);
        double previousError = detector.getEstimation();
        // ADWIN tracks the error rate: 1 for a misclassification, 0 for a correct prediction.
        boolean changed = detector.setInput(correct ? 0.0d : 1.0d);
        return changed && detector.getEstimation() > previousError;
    }

    /** Replaces the member with the highest ADWIN error estimate, if any error is above zero. */
    private void resetWorstMember() {
        double worstError = 0.0d;
        int worst = -1;
        for (int i = 0; i < this.ensemble.size(); i++) {
            double error = this.adwinEnsemble.get(i).getEstimation();
            if (error > worstError) {
                worstError = error;
                worst = i;
            }
        }
        if (worst != -1) {
            this.ensemble.get(worst).resetLearning();
            this.adwinEnsemble.set(worst, new ADWIN());
            // A fresh learner must not inherit the old member's error (and hence its voting weight).
            resetMemberStatistics(worst);
        }
    }

    /**
     * Combines the members' normalized votes, each weighted by {@code log((1 - e) / e)}.
     *
     * <p>Members whose error {@code e} is 0 or at least 0.5 receive no weight, which avoids
     * infinite or non-positive weights. If no member has a usable weight, the normalized votes are
     * summed unweighted, so the ensemble still produces a prediction early in the stream.
     *
     * @param instance the instance to classify
     * @return the combined (unnormalized) class votes; empty if the ensemble is not initialised
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        if (this.ensemble == null) {
            return new double[0];
        }
        Instance copy = instance.copy();
        DoubleVector weightedVotes = new DoubleVector();
        DoubleVector unweightedVotes = new DoubleVector();
        for (int i = 0; i < this.ensemble.size(); i++) {
            DoubleVector vote = new DoubleVector(this.ensemble.get(i).getVotesForInstance(copy));
            if (vote.sumOfValues() > 0.0d) {
                // Normalize BEFORE scaling; normalizing afterwards would cancel the weight.
                vote.normalize();
                unweightedVotes.addValues(vote);
                double weight = getEnsembleMemberWeight(i);
                if (weight > 0.0d) {
                    vote.scaleValues(weight);
                    weightedVotes.addValues(vote);
                }
            }
        }
        return weightedVotes.sumOfValues() > 0.0d ? weightedVotes.getArrayRef() : unweightedVotes.getArrayRef();
    }

    /** Returns member {@code i}'s voting weight, or 0 when its error makes the weight unusable. */
    private double getEnsembleMemberWeight(int i) {
        double error = this.epsilon.get(i);
        if (error <= 0.0d || error >= 0.5d) {
            return 0.0d;
        }
        return Math.log((1.0d - error) / error);
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally left empty: this classifier produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** Returns no classifier-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /**
     * Grows the ensemble so it has at least {@code minSize} members; it never shrinks it.
     *
     * @param minSize the minimum number of members (the caller passes the number of classes)
     */
    protected void adjustEnsembleSize(int minSize) {
        while (this.nEstimators < minSize) {
            addMember();
            this.nEstimators++;
        }
    }

    /** Appends a fresh copy of the base learner, plus zeroed statistics and optional detector. */
    private void addMember() {
        this.ensemble.add(this.baseLearner.copy());
        if (this.driftDetection) {
            this.adwinEnsemble.add(new ADWIN());
        }
        this.lambdaSc.add(0.0d);
        this.lambdaPos.add(0.0d);
        this.lambdaNeg.add(0.0d);
        this.lambdaSw.add(0.0d);
        this.epsilon.add(0.0d);
    }

    /** Zeroes all running statistics of member {@code i}. */
    private void resetMemberStatistics(int i) {
        this.lambdaSc.set(i, 0.0d);
        this.lambdaPos.set(i, 0.0d);
        this.lambdaNeg.set(i, 0.0d);
        this.lambdaSw.set(i, 0.0d);
        this.epsilon.set(i, 0.0d);
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        return getClass() == OnlineRUSBoost.class ? new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE) : new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
}
