package moa.classifiers.meta;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.ClassOption;
import moa.tasks.TaskMonitor;
import org.jfree.chart.axis.ValueAxis;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/AccuracyWeightedEnsemble.class */
/**
 * Accuracy Weighted Ensemble (AWE) classifier, after Wang et al., "Mining concept-drifting data
 * streams using ensemble classifiers", KDD 2003.
 *
 * <p>Training instances are buffered into chunks of {@code chunkSize} instances. When a chunk is
 * complete:
 * <ol>
 *   <li>a fresh candidate learner is weighted by {@code numFolds}-fold cross-validation on the
 *       chunk;</li>
 *   <li>every stored learner is re-weighted on the chunk;</li>
 *   <li>the candidate is trained on the chunk and stored, or, once {@code storedCount} learners
 *       are stored, it replaces the poorest stored learner if its weight is higher;</li>
 *   <li>the ensemble is rebuilt from the {@code memberCount} highest-weighted stored learners.</li>
 * </ol>
 *
 * <p>A learner's weight is {@code max(MSE_r - MSE_i, 0)}. {@code MSE_i} is the learner's mean
 * squared error on the normalized vote for the true class. {@code MSE_r} is the reference error
 * {@code sum_c p(c) * (1 - p(c))^2} computed from the chunk's class distribution.
 */
public class AccuracyWeightedEnsemble extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;

    /** Orders {@code {weight, storedIndex}} rows of {@link #storedWeights} by ascending weight. */
    protected static Comparator<double[]> weightComparator = new ClassifierWeightComparator();

    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", Classifier.class, "trees.HoeffdingTree -l NB -e 1000 -g 100 -c 0.01");
    public FloatOption memberCountOption = new FloatOption("memberCount", 'n', "The maximum number of classifier in an ensemble.", 15.0d, 1.0d, Integer.MAX_VALUE);
    public FloatOption storedCountOption = new FloatOption("storedCount", 'r', "The maximum number of classifiers to store and choose from when creating an ensemble.", 30.0d, 1.0d, Integer.MAX_VALUE);
    // Default chunk size is 500 instances (previously spelled via an unrelated JFreeChart constant).
    public IntOption chunkSizeOption = new IntOption("chunkSize", 'c', "The chunk size used for classifier creation and evaluation.", 500, 1, Integer.MAX_VALUE);
    public IntOption numFoldsOption = new IntOption("numFolds", 'f', "Number of cross-validation folds for candidate classifier testing.", 10, 1, Integer.MAX_VALUE);

    /** Per-class instance counts of the chunk currently being collected; null between chunks. */
    protected long[] classDistributions;
    /** Current ensemble members, ordered by descending weight. */
    protected Classifier[] ensemble;
    /** All stored learners; {@link #storedWeights} rows point into this array by index. */
    protected Classifier[] storedLearners;
    /** Weights of {@link #ensemble} members, parallel to that array. */
    protected double[] ensembleWeights;
    /** Rows of {@code {weight, index into storedLearners}}; reordered by sorting. */
    protected double[][] storedWeights;
    protected int processedInstances;
    protected int chunkSize;
    protected int numFolds;
    protected int maxMemberCount;
    protected int maxStoredCount;
    /** Fresh learner that will be trained on the current chunk once it is complete. */
    protected Classifier candidateClassifier;
    /** Buffer of the chunk currently being collected; null between chunks. */
    protected Instances currentChunk;

    /** Compares {@code {weight, index}} rows by their weight (element 0), ascending. */
    private static final class ClassifierWeightComparator implements Comparator<double[]> {
        private ClassifierWeightComparator() {
        }

        @Override
        public int compare(double[] first, double[] second) {
            if (first[0] > second[0]) {
                return 1;
            }
            return first[0] < second[0] ? -1 : 0;
        }
    }

    @Override
    public String getPurposeString() {
        return "Accuracy Weighted Ensemble classifier as proposed by Wang et al. in 'Mining concept-drifting data streams using ensemble classifiers', KDD 2003";
    }

    /**
     * Reads the option values and prepares the first candidate learner. The stored count is
     * raised to at least the member count so the ensemble can always be filled from the store.
     */
    @Override
    public void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
        this.maxMemberCount = (int) this.memberCountOption.getValue();
        this.maxStoredCount = (int) this.storedCountOption.getValue();
        if (this.maxMemberCount > this.maxStoredCount) {
            this.maxStoredCount = this.maxMemberCount;
        }
        this.chunkSize = this.chunkSizeOption.getValue();
        this.numFolds = this.numFoldsOption.getValue();
        this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
        super.prepareForUseImpl(taskMonitor, objectRepository);
    }

    /**
     * Discards all learned state: the pending chunk, the stored learners with their weights, the
     * ensemble with its weights, and the candidate learner.
     */
    @Override
    public void resetLearningImpl() {
        this.currentChunk = null;
        this.classDistributions = null;
        this.processedInstances = 0;
        this.ensemble = new Classifier[0];
        this.ensembleWeights = new double[0];
        this.storedLearners = new Classifier[0];
        // Clear weights too; otherwise measurements would report weights of discarded learners.
        this.storedWeights = new double[0][];
        this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
    }

    /** Buffers the instance into the current chunk and processes the chunk once it is full. */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        initVariables();
        this.classDistributions[(int) instance.classValue()]++;
        this.currentChunk.add(instance);
        this.processedInstances++;
        if (this.processedInstances % this.chunkSize == 0) {
            processChunk();
        }
    }

    /** Lazily creates the chunk buffer and the zeroed per-class counters for a new chunk. */
    private void initVariables() {
        if (this.currentChunk == null) {
            this.currentChunk = new Instances(getModelContext());
        }
        if (this.classDistributions == null) {
            // Java zero-initialises the array, so no explicit clearing is needed.
            this.classDistributions = new long[getModelContext().classAttribute().numValues()];
        }
    }

    /**
     * Weights the candidate and all stored learners on the completed chunk, stores the candidate
     * (or uses it to replace the poorest stored learner), rebuilds the ensemble, and starts a new
     * chunk with a fresh candidate.
     */
    protected void processChunk() {
        double candidateWeight = computeCandidateWeight(this.candidateClassifier, this.currentChunk, this.numFolds);

        // Re-evaluate every stored learner on the newest chunk.
        for (int i = 0; i < this.storedLearners.length; i++) {
            double[] entry = this.storedWeights[i];
            entry[0] = computeWeight(this.storedLearners[(int) entry[1]], this.currentChunk);
        }

        if (this.storedLearners.length < this.maxStoredCount) {
            trainCandidateOnCurrentChunk();
            addToStored(this.candidateClassifier, candidateWeight);
        } else {
            Arrays.sort(this.storedWeights, weightComparator);
            // After the ascending sort, row 0 describes the poorest stored learner.
            double[] poorest = this.storedWeights[0];
            if (poorest[0] < candidateWeight) {
                trainCandidateOnCurrentChunk();
                poorest[0] = candidateWeight;
                this.storedLearners[(int) poorest[1]] = this.candidateClassifier.copy();
            }
        }

        rebuildEnsemble();

        this.classDistributions = null;
        this.currentChunk = null;
        this.candidateClassifier = (Classifier) getPreparedClassOption(this.learnerOption);
        this.candidateClassifier.resetLearning();
    }

    /** Trains the candidate learner on every instance of the current chunk. */
    private void trainCandidateOnCurrentChunk() {
        for (int i = 0; i < this.currentChunk.numInstances(); i++) {
            this.candidateClassifier.trainOnInstance(this.currentChunk.instance(i));
        }
    }

    /** Rebuilds the ensemble from the highest-weighted stored learners, best first. */
    private void rebuildEnsemble() {
        int storeSize = this.storedLearners.length;
        int ensembleSize = Math.min(storeSize, this.maxMemberCount);
        this.ensemble = new Classifier[ensembleSize];
        this.ensembleWeights = new double[ensembleSize];
        Arrays.sort(this.storedWeights, weightComparator);
        for (int i = 0; i < ensembleSize; i++) {
            // Rows are sorted ascending, so walk from the end for descending weight.
            double[] entry = this.storedWeights[storeSize - i - 1];
            this.ensembleWeights[i] = entry[0];
            this.ensemble[i] = this.storedLearners[(int) entry[1]];
        }
    }

    /**
     * Estimates a candidate learner's weight by cross-validation on a chunk.
     *
     * <p>The chunk is copied, shuffled with a fixed seed (1), and stratified when the class is
     * nominal. For each fold, a copy of {@code classifier} is trained on the training part and
     * weighted with {@link #computeWeight} on the test part.
     *
     * @param classifier untrained learner to evaluate; it is copied and not modified
     * @param instances  the chunk to cross-validate on; it is copied and not modified
     * @param folds      number of cross-validation folds
     * @return the mean weight over all folds, or {@link Double#MAX_VALUE} if that mean is infinite
     */
    protected double computeCandidateWeight(Classifier classifier, Instances instances, int folds) {
        Random random = new Random(1L);
        Instances data = new Instances(instances);
        data.randomize(random);
        if (data.classAttribute().isNominal()) {
            data.stratify(folds);
        }
        double weightSum = 0.0d;
        for (int fold = 0; fold < folds; fold++) {
            Instances train = data.trainCV(folds, fold, random);
            Instances test = data.testCV(folds, fold);
            Classifier learner = classifier.copy();
            for (int j = 0; j < train.numInstances(); j++) {
                learner.trainOnInstance(train.instance(j));
            }
            weightSum += computeWeight(learner, test);
        }
        double meanWeight = weightSum / folds;
        return Double.isInfinite(meanWeight) ? Double.MAX_VALUE : meanWeight;
    }

    /**
     * Computes the AWE weight {@code max(MSE_r - MSE_i, 0)} of a learner on a set of instances.
     *
     * <p>{@code MSE_i} is the mean, over the given instances, of {@code (1 - p)^2}, where {@code p}
     * is the learner's vote for the true class divided by the sum of its votes. Instances with no
     * positive vote sum, or on which the learner fails, count as error 1. The mean is taken over
     * {@code instances.numInstances()} rather than the configured chunk size, so cross-validation
     * test folds are scored on the same scale as full chunks. An empty set yields
     * {@code MSE_i = 0}.
     *
     * @param classifier learner to evaluate
     * @param instances  instances to evaluate on
     * @return the non-negative weight
     */
    protected double computeWeight(Classifier classifier, Instances instances) {
        int count = instances.numInstances();
        double squaredErrorSum = 0.0d;
        for (int i = 0; i < count; i++) {
            squaredErrorSum += trueClassSquaredError(classifier, instances.instance(i));
        }
        double mseI = count > 0 ? squaredErrorSum / count : 0.0d;
        return Math.max(computeMseR() - mseI, 0.0d);
    }

    /** Returns {@code (1 - p_true)^2} for one instance, or 1 when it cannot be computed. */
    private static double trueClassSquaredError(Classifier classifier, Instance instance) {
        try {
            double[] votes = classifier.getVotesForInstance(instance);
            double voteSum = 0.0d;
            for (double vote : votes) {
                voteSum += vote;
            }
            if (!(voteSum > 0.0d)) {
                return 1.0d;
            }
            int trueClass = (int) instance.classValue();
            if (trueClass < 0 || trueClass >= votes.length) {
                // The votes array does not cover the true class: treat its probability as 0.
                return 1.0d;
            }
            double p = votes[trueClass] / voteSum;
            return (1.0d - p) * (1.0d - p);
        } catch (RuntimeException e) {
            // Deliberate best effort: a learner that fails on an instance is scored as fully wrong.
            return 1.0d;
        }
    }

    /**
     * Computes the reference error {@code sum_c p(c) * (1 - p(c))^2}. Here {@code p(c)} is the
     * count of class {@code c} in the current chunk divided by the configured chunk size.
     */
    protected double computeMseR() {
        double mseR = 0.0d;
        for (long classCount : this.classDistributions) {
            double p = classCount / (double) this.chunkSize;
            mseR += p * (1.0d - p) * (1.0d - p);
        }
        return mseR;
    }

    /**
     * Combines the members' votes. Each member with positive weight and a positive vote sum
     * contributes its normalized votes, scaled by its weight. The combined vector is normalized.
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        DoubleVector combined = new DoubleVector();
        if (this.trainingWeightSeenByModel > 0.0d) {
            for (int i = 0; i < this.ensemble.length; i++) {
                if (this.ensembleWeights[i] > 0.0d) {
                    DoubleVector memberVotes = new DoubleVector(this.ensemble[i].getVotesForInstance(instance));
                    if (memberVotes.sumOfValues() > 0.0d) {
                        memberVotes.normalize();
                        memberVotes.scaleValues(this.ensembleWeights[i] / ((1.0d * this.ensemble.length) + 1.0d));
                        combined.addValues(memberVotes);
                    }
                }
            }
        }
        combined.normalize();
        return combined.getArrayRef();
    }

    @Override
    public void getModelDescription(StringBuilder sb, int i) {
        // Intentionally empty: no textual model description is produced.
    }

    /**
     * Reports one weight measurement per store slot, best first. Slots held by ensemble members
     * are labelled "Member weight"; the remaining slots are labelled "Stored member weight".
     * Empty slots report -1.
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurements = new Measurement[this.maxStoredCount];
        for (int i = 0; i < measurements.length; i++) {
            measurements[i] = new Measurement(weightLabel(i, i < this.maxMemberCount), -1.0d);
        }
        if (this.storedWeights != null) {
            int storeSize = this.storedWeights.length;
            // Never write past the array even if the store outgrew the current option value.
            int reported = Math.min(storeSize, measurements.length);
            for (int i = 0; i < reported; i++) {
                double weight = this.storedWeights[storeSize - i - 1][0];
                measurements[i] = new Measurement(weightLabel(i, i < this.ensemble.length), weight);
            }
        }
        return measurements;
    }

    /** Builds the measurement name for a 0-based slot index. */
    private static String weightLabel(int slot, boolean member) {
        return (member ? "Member weight " : "Stored member weight ") + (slot + 1);
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    @Override
    public Classifier[] getSubClassifiers() {
        return this.ensemble.clone();
    }

    /**
     * Appends a copy of {@code classifier} to the store, with weight {@code d}.
     *
     * @param classifier learner to store; a copy is stored, not the instance itself
     * @param d          initial weight of the stored copy
     * @return the stored copy
     */
    protected Classifier addToStored(Classifier classifier, double d) {
        int oldSize = this.storedLearners.length;
        Classifier storedCopy = classifier.copy();

        Classifier[] learners = Arrays.copyOf(this.storedLearners, oldSize + 1);
        learners[oldSize] = storedCopy;

        double[][] weights = new double[oldSize + 1][];
        for (int i = 0; i < oldSize; i++) {
            weights[i] = new double[] {this.storedWeights[i][0], this.storedWeights[i][1]};
        }
        weights[oldSize] = new double[] {d, oldSize};

        this.storedLearners = learners;
        this.storedWeights = weights;
        return storedCopy;
    }

    /**
     * Removes the lowest-weighted ensemble member and returns its byte size, or 0 if the ensemble
     * is empty. The member is removed from the ensemble only, not from the store.
     */
    protected int removePoorestModelBytes() {
        if (this.ensemble == null || this.ensemble.length == 0) {
            return 0;
        }
        int minIndex = Utils.minIndex(this.ensembleWeights);
        int byteSize = this.ensemble[minIndex].measureByteSize();
        discardModel(minIndex);
        return byteSize;
    }

    /** Removes the ensemble member at index {@code i} together with its weight. */
    protected void discardModel(int i) {
        Classifier[] remaining = new Classifier[this.ensemble.length - 1];
        double[] remainingWeights = new double[remaining.length];
        int source = 0;
        for (int target = 0; target < remaining.length; target++) {
            if (source == i) {
                source++;
            }
            remaining[target] = this.ensemble[source];
            remainingWeights[target] = this.ensembleWeights[source];
            source++;
        }
        this.ensemble = remaining;
        this.ensembleWeights = remainingWeights;
    }
}
