package moa.evaluation;

import com.github.javacliparser.FlagOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceImpl;
import com.yahoo.labs.samoa.instances.Prediction;
import java.io.Serializable;
import java.util.Iterator;
import java.util.TreeSet;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;
import weka.classifiers.evaluation.ThresholdCurve;

/* loaded from: input_file:lib/moa.jar:moa/evaluation/BasicAUCImbalancedPerformanceEvaluator.class */
public class BasicAUCImbalancedPerformanceEvaluator extends AbstractOptionHandler implements ClassificationPerformanceEvaluator {
    private static final long serialVersionUID = 1;
    public FlagOption calculateAUC = new FlagOption("calculateAUC", 'a', "Determines whether AUC should be calculated. To calculate AUC, predictions need to be remembered, therefore setting this option for large streams can cause substantial memory usage.");
    protected double totalObservedInstances = 0.0d;
    private Estimator aucEstimator;
    private SimpleEstimator weightMajorityClassifier;
    protected int numClasses;

    /* loaded from: input_file:lib/moa.jar:moa/evaluation/BasicAUCImbalancedPerformanceEvaluator$Estimator.class */
    public class Estimator {
        protected TreeSet<Score> sortedScores = new TreeSet<>();
        protected int position;
        protected double numPos;
        protected double numNeg;
        protected double correctPredictions;
        protected double correctPositivePredictions;
        protected double[] columnKappa;
        protected double[] rowKappa;
        protected boolean calculateAuc;

        /* loaded from: input_file:lib/moa.jar:moa/evaluation/BasicAUCImbalancedPerformanceEvaluator$Estimator$Score.class */
        public class Score implements Comparable<Score> {
            protected double value;
            protected int position;
            protected boolean isPositive;

            public Score(double d, int i, boolean z) {
                this.value = d;
                this.isPositive = z;
                this.position = i;
            }

            @Override // java.lang.Comparable
            public int compareTo(Score score) {
                if (score.value < this.value) {
                    return -1;
                }
                if (score.value > this.value) {
                    return 1;
                }
                if (!score.isPositive && this.isPositive) {
                    return -1;
                }
                if (score.isPositive && !this.isPositive) {
                    return 1;
                }
                if (score.position > this.position) {
                    return -1;
                }
                return score.position < this.position ? 1 : 0;
            }

            public boolean equals(Object obj) {
                return (obj instanceof Score) && ((Score) obj).position == this.position;
            }
        }

        public Estimator(boolean z) {
            this.calculateAuc = z;
            this.rowKappa = new double[BasicAUCImbalancedPerformanceEvaluator.this.numClasses];
            this.columnKappa = new double[BasicAUCImbalancedPerformanceEvaluator.this.numClasses];
            for (int i = 0; i < BasicAUCImbalancedPerformanceEvaluator.this.numClasses; i++) {
                this.rowKappa[i] = 0.0d;
                this.columnKappa[i] = 0.0d;
            }
            this.position = 0;
            this.numPos = 0.0d;
            this.numNeg = 0.0d;
            this.correctPredictions = 0.0d;
            this.correctPositivePredictions = 0.0d;
        }

        public void add(double d, boolean z, boolean z2) {
            Score score = new Score(d, this.position, z);
            if (this.calculateAuc) {
                this.sortedScores.add(score);
            }
            this.correctPredictions += z2 ? 1.0d : 0.0d;
            this.correctPositivePredictions += (z2 && z) ? 1.0d : 0.0d;
            int i = z ? 1 : 0;
            int abs = z2 ? i : Math.abs(i - 1);
            double[] dArr = this.rowKappa;
            dArr[abs] = dArr[abs] + 1.0d;
            double[] dArr2 = this.columnKappa;
            dArr2[i] = dArr2[i] + 1.0d;
            if (score.isPositive) {
                this.numPos += 1.0d;
            } else {
                this.numNeg += 1.0d;
            }
            this.position++;
        }

        public double getAUC() {
            double d = 0.0d;
            double d2 = 0.0d;
            double d3 = 0.0d;
            double d4 = Double.MAX_VALUE;
            if (!this.calculateAuc) {
                return -1.0d;
            }
            if (this.numPos == 0.0d || this.numNeg == 0.0d) {
                return 1.0d;
            }
            Iterator<Score> it = this.sortedScores.iterator();
            while (it.hasNext()) {
                Score next = it.next();
                if (next.isPositive) {
                    if (next.value != d4) {
                        d3 = d2;
                        d4 = next.value;
                    }
                    d2 += 1.0d;
                } else {
                    d = next.value == d4 ? d + ((d2 + d3) / 2.0d) : d + d2;
                }
            }
            return d / (this.numPos * this.numNeg);
        }

        public double getScoredAUC() {
            double d = 0.0d;
            double d2 = 0.0d;
            double d3 = 0.0d;
            double d4 = 0.0d;
            double d5 = 0.0d;
            double d6 = 0.0d;
            double d7 = Double.MAX_VALUE;
            double d8 = Double.MAX_VALUE;
            if (!this.calculateAuc) {
                return -1.0d;
            }
            if (this.numPos == 0.0d || this.numNeg == 0.0d) {
                return 1.0d;
            }
            Iterator<Score> it = this.sortedScores.iterator();
            while (it.hasNext()) {
                Score next = it.next();
                if (next.isPositive) {
                    if (next.value != d7) {
                        d6 = d5;
                        d7 = next.value;
                    }
                    d5 += next.value;
                    d = next.value == d8 ? d + ((d3 + d4) / 2.0d) : d + d3;
                } else {
                    if (next.value != d8) {
                        d4 = d3;
                        d8 = next.value;
                    }
                    d3 += next.value;
                    d2 = next.value == d7 ? d2 + ((d5 + d6) / 2.0d) : d2 + d5;
                }
            }
            return (d2 / (this.numPos * this.numNeg)) - (((this.numPos * d3) - d) / (this.numPos * this.numNeg));
        }

        public double getRatio() {
            if (this.numNeg == 0.0d) {
                return Double.MAX_VALUE;
            }
            return this.numPos / this.numNeg;
        }

        public double getAccuracy() {
            if (BasicAUCImbalancedPerformanceEvaluator.this.totalObservedInstances > 0.0d) {
                return this.correctPredictions / BasicAUCImbalancedPerformanceEvaluator.this.totalObservedInstances;
            }
            return 0.0d;
        }

        public double getKappa() {
            double accuracy = getAccuracy();
            double d = 0.0d;
            for (int i = 0; i < BasicAUCImbalancedPerformanceEvaluator.this.numClasses; i++) {
                d += (this.rowKappa[i] / BasicAUCImbalancedPerformanceEvaluator.this.totalObservedInstances) * (this.columnKappa[i] / BasicAUCImbalancedPerformanceEvaluator.this.totalObservedInstances);
            }
            return (accuracy - d) / (1.0d - d);
        }

        /* JADX INFO: Access modifiers changed from: private */
        public double getKappaM() {
            double accuracy = getAccuracy();
            double estimation = BasicAUCImbalancedPerformanceEvaluator.this.weightMajorityClassifier.estimation();
            return (accuracy - estimation) / (1.0d - estimation);
        }

        public double getGMean() {
            return Math.sqrt((this.correctPositivePredictions / this.numPos) * ((this.correctPredictions - this.correctPositivePredictions) / this.numNeg));
        }

        public double getRecall() {
            return this.correctPositivePredictions / this.numPos;
        }
    }

    /* loaded from: input_file:lib/moa.jar:moa/evaluation/BasicAUCImbalancedPerformanceEvaluator$SimpleEstimator.class */
    public class SimpleEstimator {
        protected double len;
        protected double sum;

        public SimpleEstimator() {
        }

        public void add(double d) {
            this.sum += d;
            this.len += 1.0d;
        }

        public double estimation() {
            return this.sum / this.len;
        }
    }

    @Override // moa.evaluation.LearningPerformanceEvaluator
    public void reset() {
        reset(this.numClasses);
    }

    public void reset(int i) {
        if (i != 2) {
            throw new RuntimeException("Too many classes (" + i + "). AUC evaluation can be performed only for two-class problems!");
        }
        this.numClasses = i;
        this.aucEstimator = new Estimator(this.calculateAUC.isSet());
        this.weightMajorityClassifier = new SimpleEstimator();
        this.totalObservedInstances = 0.0d;
    }

    @Override // moa.evaluation.LearningPerformanceEvaluator
    public void addResult(Example<Instance> example, double[] dArr) {
        InstanceImpl instanceImpl = (InstanceImpl) example.getData();
        double weight = instanceImpl.weight();
        if (instanceImpl.classIsMissing()) {
            return;
        }
        int classValue = (int) instanceImpl.classValue();
        if (weight > 0.0d) {
            if (this.totalObservedInstances == 0.0d) {
                reset(instanceImpl.dataset().numClasses());
            }
            this.totalObservedInstances += 1.0d;
            Double valueOf = Double.valueOf(0.0d);
            if (dArr.length == 2) {
                valueOf = Double.valueOf(dArr[1] / (dArr[0] + dArr[1]));
            }
            if (valueOf.isNaN()) {
                valueOf = Double.valueOf(0.0d);
            }
            this.aucEstimator.add(valueOf.doubleValue(), classValue == 1, Utils.maxIndex(dArr) == classValue);
            this.weightMajorityClassifier.add(((this.aucEstimator.getRatio() > 1.0d ? 1 : (this.aucEstimator.getRatio() == 1.0d ? 0 : -1)) <= 0 ? 0 : 1) == classValue ? weight : 0.0d);
        }
    }

    @Override // moa.evaluation.LearningPerformanceEvaluator
    public Measurement[] getPerformanceMeasurements() {
        return new Measurement[]{new Measurement("classified instances", this.totalObservedInstances), new Measurement("AUC", this.aucEstimator.getAUC()), new Measurement("sAUC", this.aucEstimator.getScoredAUC()), new Measurement("Accuracy", this.aucEstimator.getAccuracy()), new Measurement("Kappa", this.aucEstimator.getKappa()), new Measurement("Pos/Neg ratio", this.aucEstimator.getRatio()), new Measurement("G-Mean", this.aucEstimator.getGMean()), new Measurement(ThresholdCurve.RECALL_NAME, this.aucEstimator.getRecall()), new Measurement("KappaM", this.aucEstimator.getKappaM())};
    }

    @Override // moa.MOAObject
    public void getDescription(StringBuilder sb, int i) {
        Measurement.getMeasurementsDescription(getPerformanceMeasurements(), sb, i);
    }

    @Override // moa.options.AbstractOptionHandler
    public void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
    }

    public Estimator getAucEstimator() {
        return this.aucEstimator;
    }

    @Override // moa.evaluation.LearningPerformanceEvaluator
    public void addResult(Example<Instance> example, Prediction prediction) {
        throw new RuntimeException("Designed for scoring classifiers");
    }
}
