package moa.classifiers.meta.imbalanced;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.core.Utils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/imbalanced/OnlineAdaC2.class */
/**
 * Online cost-sensitive boosting ensemble (OnlineAdaC2) for imbalanced data streams, adapted
 * from B. Wang and J. Pineau.
 *
 * <p>Training: each incoming instance is passed sequentially through the ensemble.
 * <ul>
 *   <li>Member {@code i} is trained {@code Poisson(lambda)} times.</li>
 *   <li>{@code lambda} is then re-weighted for member {@code i + 1}, using the misclassification
 *       cost of the instance's true class and member {@code i}'s running cost-weighted accuracy
 *       or error.</li>
 * </ul>
 *
 * <p>Prediction: a weighted sum of the members' normalized votes, with member weight
 * {@code log(wAcc / wErr)}.
 *
 * <p>Class values: the boosting statistics are only updated when both the true class and the
 * predicted class are 0 (negative) or 1 (positive). Any other class index leaves {@code lambda}
 * unchanged.
 *
 * <p>Drift detection: unless disabled, one ADWIN detector per member tracks that member's error
 * rate. When any detector signals an error increase, the member with the highest estimated error
 * is reset.
 */
public class OnlineAdaC2 extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
    private static final long serialVersionUID = 1L;

    /** Smoothing term that keeps {@code log(wAcc / wErr)} finite when either statistic is zero. */
    private static final double WEIGHT_EPSILON = 1e-10;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "meta.AdaptiveRandomForest");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The size of the ensemble.", 10, 1, Integer.MAX_VALUE);
    public FloatOption costPositiveOption = new FloatOption("costPositive", 'p', "The cost of misclassifying a positive sample.", 1.0d, 0.1d, 1.0d);
    public FloatOption costNegativeOption = new FloatOption("costNegative", 'n', "The cost of misclassifying a negative sample.", 0.1d, 0.1d, 1.0d);
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'd', "Should use ADWIN as drift detector?");

    /** Prepared prototype that every ensemble member is copied from. */
    protected Classifier baseLearner;
    /** Current number of ensemble members. */
    protected int nEstimators;
    /** Cost applied when the true class is 1. */
    protected double costPositive;
    /** Cost applied when the true class is 0. */
    protected double costNegative;
    /** Whether per-member ADWIN drift detection is active. */
    protected boolean driftDetection;
    protected ArrayList<Classifier> ensemble;
    /** One ADWIN per member; only allocated when drift detection is active. */
    protected ArrayList<ADWIN> adwinEnsemble;
    // Per-member cost-weighted sums of lambda for true positives/negatives and false positives/negatives.
    protected ArrayList<Double> lambdaTP;
    protected ArrayList<Double> lambdaTN;
    protected ArrayList<Double> lambdaFP;
    protected ArrayList<Double> lambdaFN;
    /** Per-member sum of every lambda the member has been offered. */
    protected ArrayList<Double> lambdaSum;
    /** Per-member (lambdaTP + lambdaTN) / lambdaSum. */
    protected ArrayList<Double> wAcc;
    /** Per-member (lambdaFP + lambdaFN) / lambdaSum. */
    protected ArrayList<Double> wErr;

    @Override // moa.classifiers.AbstractClassifier, moa.options.AbstractOptionHandler, moa.options.OptionHandler
    public String getPurposeString() {
        return "OnlineAdaC2 is the adaptation of the ensemble learner to data streams from B. Wang and J. Pineau";
    }

    /** Re-reads all options and rebuilds the ensemble with fresh members and zeroed statistics. */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        this.baseLearner.resetLearning();
        this.nEstimators = this.ensembleSizeOption.getValue();
        this.costPositive = this.costPositiveOption.getValue();
        this.costNegative = this.costNegativeOption.getValue();
        this.driftDetection = !this.disableDriftDetectionOption.isSet();
        this.ensemble = new ArrayList<>();
        if (this.driftDetection) {
            this.adwinEnsemble = new ArrayList<>();
        }
        this.lambdaTP = new ArrayList<>();
        this.lambdaTN = new ArrayList<>();
        this.lambdaFP = new ArrayList<>();
        this.lambdaFN = new ArrayList<>();
        this.lambdaSum = new ArrayList<>();
        this.wAcc = new ArrayList<>();
        this.wErr = new ArrayList<>();
        for (int i = 0; i < this.nEstimators; i++) {
            addMember();
        }
        this.classifierRandom = new Random(this.randomSeed);
    }

    /**
     * Trains every member on {@code instance}, boosting-style, and handles drift detection.
     *
     * @param instance the labelled training instance
     */
    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        if (this.ensemble.isEmpty()) {
            resetLearningImpl();
        }
        adjustEnsembleSize(instance.numClasses());
        double trueClass = instance.classValue();
        // Importance weight carried from one member to the next (AdaC2 re-weighting).
        double lambda = 1.0d;
        boolean changeDetected = false;
        for (int i = 0; i < this.ensemble.size(); i++) {
            Classifier member = this.ensemble.get(i);
            this.lambdaSum.set(i, this.lambdaSum.get(i) + lambda);
            double k = MiscUtils.poisson(lambda, this.classifierRandom);
            for (int j = 0; j < k; j++) {
                member.trainOnInstance(instance);
            }
            boolean trained = k > 0.0d;
            if (!trained && !this.driftDetection) {
                continue;
            }
            // Predict once; both the boosting update and the drift detector use this prediction.
            int predicted = Utils.maxIndex(member.getVotesForInstance(instance));
            if (trained) {
                lambda = updateBoostingStatistics(i, predicted, trueClass, lambda);
            }
            if (this.driftDetection && feedDriftDetector(i, predicted == trueClass)) {
                changeDetected = true;
            }
        }
        if (changeDetected && this.driftDetection) {
            resetWorstMember();
        }
    }

    /**
     * Returns the weighted vote of the ensemble.
     *
     * <p>Each member's vote is normalized to sum 1 and then scaled by the member's weight
     * {@code log(wAcc / wErr)}. Members whose vote sums to zero are skipped.
     */
    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        DoubleVector combinedVote = new DoubleVector();
        for (int i = 0; i < this.ensemble.size(); i++) {
            DoubleVector vote = new DoubleVector(this.ensemble.get(i).getVotesForInstance(testInstance));
            if (vote.sumOfValues() > 0.0d) {
                // Normalize BEFORE weighting: scaling first and then normalizing divides the
                // weight back out, so every member would count equally.
                vote.normalize();
                double weight = memberWeight(i);
                for (int j = 0; j < vote.numValues(); j++) {
                    vote.setValue(j, vote.getValue(j) * weight);
                }
                combinedVote.addValues(vote);
            }
        }
        return combinedVote.getArrayRef();
    }

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return true;
    }

    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** This model reports no extra measurements; returns an empty array rather than {@code null}. */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /**
     * Grows the ensemble so it has at least {@code i} members. It never shrinks.
     *
     * @param i minimum number of members (callers pass the number of classes)
     */
    protected void adjustEnsembleSize(int i) {
        while (this.nEstimators < i) {
            addMember();
            this.nEstimators++;
        }
    }

    @Override // moa.classifiers.AbstractClassifier, moa.capabilities.CapabilitiesHandler
    public ImmutableCapabilities defineImmutableCapabilities() {
        return getClass() == OnlineAdaC2.class ? new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE) : new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }

    /** Appends one fresh member (plus its detector, if enabled) with zeroed statistics. */
    private void addMember() {
        this.ensemble.add(this.baseLearner.copy());
        if (this.driftDetection) {
            this.adwinEnsemble.add(new ADWIN());
        }
        this.lambdaTP.add(0.0d);
        this.lambdaTN.add(0.0d);
        this.lambdaFP.add(0.0d);
        this.lambdaFN.add(0.0d);
        this.lambdaSum.add(0.0d);
        this.wAcc.add(0.0d);
        this.wErr.add(0.0d);
    }

    /** Zeroes the boosting statistics of member {@code index}, e.g. after it has been reset. */
    private void resetMemberStatistics(int index) {
        this.lambdaTP.set(index, 0.0d);
        this.lambdaTN.set(index, 0.0d);
        this.lambdaFP.set(index, 0.0d);
        this.lambdaFN.set(index, 0.0d);
        this.lambdaSum.set(index, 0.0d);
        this.wAcc.set(index, 0.0d);
        this.wErr.set(index, 0.0d);
    }

    /**
     * Records the outcome of member {@code index} on one instance and re-weights lambda.
     *
     * <p>The lambda-sum increment for this instance must already have been applied. Only class
     * values 0 and 1 are handled; for any other true or predicted class the statistics are left
     * untouched and {@code lambda} is returned unchanged.
     *
     * @return the lambda to offer the next member
     */
    private double updateBoostingStatistics(int index, int predicted, double trueClass, double lambda) {
        boolean positive = trueClass == 1.0d;
        boolean negative = trueClass == 0.0d;
        if ((!positive && !negative) || (predicted != 0 && predicted != 1)) {
            return lambda;
        }
        // The cost depends only on the true class: TP/FN use costPositive, TN/FP use costNegative.
        double cost = positive ? this.costPositive : this.costNegative;
        boolean correct = predicted == trueClass;
        ArrayList<Double> bucket = correct
                ? (positive ? this.lambdaTP : this.lambdaTN)
                : (positive ? this.lambdaFN : this.lambdaFP);
        bucket.set(index, bucket.get(index) + cost * lambda);

        double sum = this.lambdaSum.get(index);
        double acc = (this.lambdaTP.get(index) + this.lambdaTN.get(index)) / sum;
        double err = (this.lambdaFP.get(index) + this.lambdaFN.get(index)) / sum;
        this.wAcc.set(index, acc);
        this.wErr.set(index, err);
        return (cost * lambda) / (2.0d * (correct ? acc : err));
    }

    /**
     * Feeds one outcome to the member's ADWIN, which tracks the ERROR rate (1 = misclassified).
     *
     * @return true if ADWIN detected a change and the estimated error went up
     */
    private boolean feedDriftDetector(int index, boolean correct) {
        ADWIN detector = this.adwinEnsemble.get(index);
        double errorBefore = detector.getEstimation();
        return detector.setInput(correct ? 0.0d : 1.0d) && detector.getEstimation() > errorBefore;
    }

    /** Resets the member whose ADWIN reports the highest error, along with its statistics. */
    private void resetWorstMember() {
        double maxError = 0.0d;
        int worst = -1;
        for (int i = 0; i < this.ensemble.size(); i++) {
            double error = this.adwinEnsemble.get(i).getEstimation();
            if (maxError < error) {
                maxError = error;
                worst = i;
            }
        }
        if (worst != -1) {
            this.ensemble.get(worst).resetLearning();
            this.adwinEnsemble.set(worst, new ADWIN());
            resetMemberStatistics(worst);
        }
    }

    /** Voting weight of member {@code index}: smoothed {@code log(wAcc / wErr)}, always finite. */
    private double memberWeight(int index) {
        return Math.log((this.wAcc.get(index) + WEIGHT_EPSILON) / (this.wErr.get(index) + WEIGHT_EPSILON));
    }
}
