package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.ArrayList;
import java.util.Random;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.core.DoubleVector;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.evaluation.BasicClassificationPerformanceEvaluator;
import moa.options.ClassOption;
import org.apache.log4j.Priority;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/meta/StreamingRandomPatches.class */
/**
 * Streaming Random Patches (SRP) ensemble classifier.
 *
 * <p>Each ensemble member trains a copy of the base learner on one of:
 * <ul>
 *   <li>a random subset of the features ("Random Subspaces");</li>
 *   <li>a Poisson(lambda)-weighted copy of the stream ("Resampling");</li>
 *   <li>both ("Random Patches").</li>
 * </ul>
 *
 * <p>Unless disabled, every member feeds its own 0/1 errors to two detectors:
 * <ul>
 *   <li>a warning detector, which starts a background learner;</li>
 *   <li>a drift detector, which swaps that background learner in, or resets the member when there
 *       is none.</li>
 * </ul>
 *
 * <p>The ensemble prediction is the sum of the members' normalized votes. Each vote is optionally
 * weighted by a measurement taken from that member's own evaluator.
 */
public class StreamingRandomPatches extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
    private static final long serialVersionUID = 1;

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train on instances.", Classifier.class, "trees.HoeffdingTree -g 50 -c 0.01");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models.", 100, 1, Integer.MAX_VALUE);
    public MultiChoiceOption subspaceModeOption = new MultiChoiceOption("subspaceMode", 'o', "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.", new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", "Percentage (M * (m / 100))"}, new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 3);
    // Negative sizes are meaningful (relative to the number of features), hence no lower bound.
    public IntOption subspaceSizeOption = new IntOption("subspaceSize", 'm', "# attributes per subset for each classifier. Negative values = totalAttributes - #attributes", 60, Integer.MIN_VALUE, Integer.MAX_VALUE);
    public MultiChoiceOption trainingMethodOption = new MultiChoiceOption("trainingMethod", 't', "The training method to use: Random Patches, Random Subspaces or Bagging.", new String[]{"Random Subspaces", "Resampling (bagging)", "Random Patches"}, new String[]{"RandomSubspaces", "Resampling", "RandomPatches"}, 2);
    public FloatOption lambdaOption = new FloatOption("lambda", 'a', "The lambda parameter for bagging.", 6.0d, 1.0d, Float.MAX_VALUE);
    public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5");
    public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4");
    public FlagOption disableWeightedVote = new FlagOption("disableWeightedVote", 'w', "Should use weighted voting?");
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'u', "Should use drift detection? If disabled, then the bkg learner is also disabled.");
    public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q', "Should use bkg learner? If disabled, then trees are reset immediately.");

    /** Indices of {@link #trainingMethodOption}. */
    public static final int TRAIN_RANDOM_SUBSPACES = 0;
    public static final int TRAIN_RESAMPLING = 1;
    public static final int TRAIN_RANDOM_PATCHES = 2;

    /** Indices of {@link #subspaceModeOption}. */
    protected static final int FEATURES_M = 0;
    protected static final int FEATURES_SQRT = 1;
    protected static final int FEATURES_SQRT_INV = 2;
    protected static final int FEATURES_PERCENT = 3;

    /** The ensemble members; {@code null} until the first instance is seen. */
    protected StreamingRandomPatchesClassifier[] ensemble;
    /** Number of training instances seen since the last reset. */
    protected long instancesSeen;
    /** Candidate feature subspaces (feature indices) drawn from when the ensemble is built. */
    protected ArrayList<ArrayList<Integer>> subspaces;

    /**
     * One ensemble member.
     *
     * <p>A member bundles:
     * <ul>
     *   <li>a base classifier and, for subspace training, the feature indices it sees;</li>
     *   <li>its drift and warning detectors;</li>
     *   <li>an optional background learner;</li>
     *   <li>the evaluator whose measurements weight its votes.</li>
     * </ul>
     */
    public class StreamingRandomPatchesClassifier {
        /** Position of this member in the ensemble. */
        public int indexOriginal;
        /** Value of the instance counter when this member (or its promoted learner) was created. */
        public long createdOn;
        public Classifier classifier;
        /**
         * Header-only (plus one scratch instance) dataset for the projected feature space.
         * {@code null} when the member trains on all features.
         */
        public Instances subset;
        /** Original attribute indices of the subspace; the last entry is the class index. */
        public int[] featureIndexes;
        public boolean disableBkgLearner;
        public boolean disableDriftDetector;
        protected ChangeDetector driftDetectionMethod;
        protected ChangeDetector warningDetectionMethod;
        protected ClassOption driftOption;
        protected ClassOption warningOption;
        /** Learner trained in the background after a warning; promoted on drift. May be {@code null}. */
        public StreamingRandomPatchesClassifier bkgLearner;
        /** {@code true} for background learners, which never run their own detectors. */
        public boolean isBackgroundLearner;
        public BasicClassificationPerformanceEvaluator evaluator;
        public int numberOfDriftsDetected;
        public int numberOfWarningsDetected;
        public int numberOfDriftsInduced;
        public int numberOfWarningsInduced;

        /**
         * Shared initialization for both constructors.
         *
         * <p>Detectors are instantiated only for the features that are enabled.
         */
        private void init(int indexOriginal, Classifier classifier, BasicClassificationPerformanceEvaluator evaluator, long createdOn, boolean disableBkgLearner, boolean disableDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            this.indexOriginal = indexOriginal;
            this.createdOn = createdOn;
            this.classifier = classifier;
            this.evaluator = evaluator;
            this.disableBkgLearner = disableBkgLearner;
            this.disableDriftDetector = disableDriftDetector;
            if (!this.disableDriftDetector) {
                this.driftOption = driftOption;
                this.driftDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(driftOption)).copy();
            }
            if (!this.disableBkgLearner) {
                this.warningOption = warningOption;
                this.warningDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(warningOption)).copy();
            }
            this.numberOfDriftsInduced = 0;
            this.numberOfDriftsDetected = 0;
            this.numberOfWarningsInduced = 0;
            this.numberOfWarningsDetected = 0;
            this.isBackgroundLearner = isBackgroundLearner;
        }

        /** Creates a member that trains on all features (no subspace). */
        public StreamingRandomPatchesClassifier(int indexOriginal, Classifier classifier, BasicClassificationPerformanceEvaluator evaluator, long createdOn, boolean disableBkgLearner, boolean disableDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            init(indexOriginal, classifier, evaluator, createdOn, disableBkgLearner, disableDriftDetector, driftOption, warningOption, isBackgroundLearner);
            this.featureIndexes = null;
            this.subset = null;
        }

        /**
         * Creates a member restricted to a feature subspace.
         *
         * @param featureIndexes attribute indices of the subspace; the last one must be the class index
         * @param instance an instance whose header supplies the attribute definitions
         */
        public StreamingRandomPatchesClassifier(int indexOriginal, Classifier classifier, BasicClassificationPerformanceEvaluator evaluator, long createdOn, boolean disableBkgLearner, boolean disableDriftDetector, ClassOption driftOption, ClassOption warningOption, ArrayList<Integer> featureIndexes, Instance instance, boolean isBackgroundLearner) {
            init(indexOriginal, classifier, evaluator, createdOn, disableBkgLearner, disableDriftDetector, driftOption, warningOption, isBackgroundLearner);
            initSubspace(featureIndexes, instance);
        }

        /**
         * Sets {@link #featureIndexes} and rebuilds {@link #subset} for the given attribute indices.
         *
         * <p>The class attribute is placed last and the subset is primed with a projection of
         * {@code instance}.
         */
        // Raw list: the attribute element type is not imported in this file. Every element is an
        // attribute returned by instance.attribute(...), so the unchecked conversion is safe.
        @SuppressWarnings({"rawtypes", "unchecked"})
        private void initSubspace(ArrayList<Integer> indexes, Instance instance) {
            int size = indexes.size();
            this.featureIndexes = new int[size];
            ArrayList attributes = new ArrayList(size);
            for (int pos = 0; pos < size; pos++) {
                int attributeIndex = indexes.get(pos);
                attributes.add(instance.attribute(attributeIndex));
                this.featureIndexes[pos] = attributeIndex;
            }
            this.subset = new Instances("Subsets Candidate Instances", attributes, 100);
            this.subset.setClassIndex(this.subset.numAttributes() - 1);
            prepareRandomSubspaceInstance(instance, 1.0d);
        }

        /**
         * Replaces the contents of {@link #subset} with one projection of {@code instance}.
         *
         * <p>The projection contains only the subspace features plus the class value.
         *
         * @param instance the full-feature instance to project
         * @param weight weight assigned to the projected instance
         */
        public void prepareRandomSubspaceInstance(Instance instance, double weight) {
            while (this.subset.numInstances() > 0) {
                this.subset.delete(0);
            }
            double[] values = new double[this.subset.numAttributes()];
            for (int pos = 0; pos < values.length; pos++) {
                values[pos] = instance.value(this.featureIndexes[pos]);
            }
            values[values.length - 1] = instance.classValue();
            DenseInstance projected = new DenseInstance(1.0d, values);
            projected.setWeight(weight);
            projected.setDataset(this.subset);
            this.subset.add(projected);
        }

        /**
         * Draws a new random subspace with the same size as the current one.
         *
         * @return new attribute indices with the class index last, or {@code null} when this member
         *     does not use a subspace
         */
        private ArrayList<Integer> applySubsetResetStrategy(Instance instance, Random random) {
            if (this.subset == null) {
                return null;
            }
            ArrayList<Integer> candidates = new ArrayList<>(instance.numAttributes());
            for (int a = 0; a < instance.numAttributes(); a++) {
                candidates.add(a);
            }
            // remove(int) removes by position; positions equal values in this freshly built list.
            candidates.remove(instance.classIndex());
            int toDrop = instance.numAttributes() - this.featureIndexes.length;
            for (int dropped = 0; dropped < toDrop; dropped++) {
                candidates.remove(random.nextInt(candidates.size()));
            }
            candidates.add(instance.classIndex());
            return candidates;
        }

        /**
         * Reacts to a detected drift.
         *
         * <p>If a background learner exists (and background learning is enabled), it replaces this
         * member's model. Otherwise the model is reset in place and, for subspace members, a new
         * random subspace is drawn.
         */
        public void reset(Instance instance, long instancesSeen, Random random) {
            if (!this.disableBkgLearner && this.bkgLearner != null) {
                this.classifier = this.bkgLearner.classifier;
                this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod;
                this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod;
                this.evaluator = this.bkgLearner.evaluator;
                this.evaluator.reset();
                this.createdOn = this.bkgLearner.createdOn;
                this.subset = this.bkgLearner.subset;
                this.featureIndexes = this.bkgLearner.featureIndexes;
                // Detach the promoted learner. Otherwise trainOnInstance would train the now-shared
                // classifier a second time through bkgLearner on every instance.
                this.bkgLearner = null;
                return;
            }
            this.classifier.resetLearning();
            this.evaluator.reset();
            this.createdOn = instancesSeen;
            this.driftDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(this.driftOption)).copy();
            if (this.subset != null) {
                initSubspace(applySubsetResetStrategy(instance, random), instance);
            }
        }

        /**
         * Trains this member (and its background learner, if any) on one instance.
         *
         * <p>Afterwards the 0/1 error of the trained model is fed to the warning and drift detectors.
         *
         * @param weight multiplier applied to the instance's own weight (1 or a Poisson draw)
         */
        public void trainOnInstance(Instance instance, double weight, long instancesSeen, Random random) {
            // Both branches train with instance.weight() * weight so weighted streams behave the same
            // with and without a subspace.
            double effectiveWeight = instance.weight() * weight;
            boolean correctlyClassifies;
            if (this.subset != null) {
                prepareRandomSubspaceInstance(instance, effectiveWeight);
                Instance projected = this.subset.get(0);
                this.classifier.trainOnInstance(projected);
                correctlyClassifies = this.classifier.correctlyClassifies(projected);
            } else {
                Instance weighted = instance.copy();
                weighted.setWeight(effectiveWeight);
                this.classifier.trainOnInstance(weighted);
                correctlyClassifies = this.classifier.correctlyClassifies(instance);
            }
            if (this.bkgLearner != null) {
                // The background learner applies the instance's weight itself.
                this.bkgLearner.trainOnInstance(instance, weight, instancesSeen, random);
            }
            // Background learners never monitor drift; neither does a member with detection disabled.
            if (this.disableDriftDetector || this.isBackgroundLearner) {
                return;
            }
            double error = correctlyClassifies ? 0.0d : 1.0d;
            if (!this.disableBkgLearner) {
                this.warningDetectionMethod.input(error);
                if (this.warningDetectionMethod.getChange()) {
                    this.numberOfWarningsDetected++;
                    triggerWarning(instance, instancesSeen, random);
                }
            }
            this.driftDetectionMethod.input(error);
            if (this.driftDetectionMethod.getChange()) {
                this.numberOfDriftsDetected++;
                reset(instance, instancesSeen, random);
            }
        }

        /**
         * Starts a new background learner and resets the warning detector.
         *
         * <p>The learner is a freshly reset copy of the current model. For subspace members it gets
         * a newly drawn subspace.
         */
        public void triggerWarning(Instance instance, long instancesSeen, Random random) {
            Classifier freshModel = this.classifier.copy();
            freshModel.resetLearning();
            BasicClassificationPerformanceEvaluator freshEvaluator = (BasicClassificationPerformanceEvaluator) this.evaluator.copy();
            freshEvaluator.reset();
            if (this.subset == null) {
                this.bkgLearner = new StreamingRandomPatchesClassifier(this.indexOriginal, freshModel, freshEvaluator, instancesSeen, this.disableBkgLearner, this.disableDriftDetector, this.driftOption, this.warningOption, true);
            } else {
                this.bkgLearner = new StreamingRandomPatchesClassifier(this.indexOriginal, freshModel, freshEvaluator, instancesSeen, this.disableBkgLearner, this.disableDriftDetector, this.driftOption, this.warningOption, applySubsetResetStrategy(instance, random), instance, true);
            }
            this.warningDetectionMethod = ((ChangeDetector) StreamingRandomPatches.this.getPreparedClassOption(this.warningOption)).copy();
        }

        /** Returns a copy of this member's class votes, projecting onto its subspace if it has one. */
        public double[] getVotesForInstance(Instance instance) {
            if (this.subset == null) {
                return new DoubleVector(this.classifier.getVotesForInstance(instance)).getArrayRef();
            }
            prepareRandomSubspaceInstance(instance, 1.0d);
            return new DoubleVector(this.classifier.getVotesForInstance(this.subset.get(0))).getArrayRef();
        }
    }

    /**
     * Discards the ensemble so that the next instance rebuilds it from scratch.
     *
     * <p>Note: a fallback to Resampling made by {@link #initEnsemble} is written into
     * {@link #trainingMethodOption} and is not undone here.
     */
    @Override
    public void resetLearningImpl() {
        this.instancesSeen = 0L;
        this.ensemble = null;
        this.subspaces = null;
    }

    /**
     * Evaluates each member on the instance (test-then-train), then trains it.
     *
     * <p>Members are trained with weight 1 for Random Subspaces. Otherwise they get a Poisson(lambda)
     * weight, and a draw of 0 skips training.
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        this.instancesSeen++;
        if (this.ensemble == null) {
            initEnsemble(instance);
        }
        boolean resample = this.trainingMethodOption.getChosenIndex() != TRAIN_RANDOM_SUBSPACES;
        for (StreamingRandomPatchesClassifier member : this.ensemble) {
            // member.getVotesForInstance already returns a fresh array.
            member.evaluator.addResult(new InstanceExample(instance), member.getVotesForInstance(instance));
            if (!resample) {
                member.trainOnInstance(instance, 1.0d, this.instancesSeen, this.classifierRandom);
                continue;
            }
            int k = MiscUtils.poisson(this.lambdaOption.getValue(), this.classifierRandom);
            if (k > 0) {
                member.trainOnInstance(instance, k, this.instancesSeen, this.classifierRandom);
            }
        }
    }

    /**
     * Sums the members' normalized votes.
     *
     * <p>Unless weighted voting is disabled, each vote is scaled by a measurement from that member's
     * evaluator. Zero-sum votes are ignored.
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        testInstance.setMissing(instance.classAttribute());
        testInstance.setClassValue(0.0d);
        if (this.ensemble == null) {
            initEnsemble(testInstance);
        }
        DoubleVector combined = new DoubleVector();
        for (StreamingRandomPatchesClassifier member : this.ensemble) {
            DoubleVector vote = new DoubleVector(member.getVotesForInstance(testInstance));
            if (vote.sumOfValues() <= 0.0d) {
                continue;
            }
            vote.normalize();
            // Measurement #1 of the member's evaluator (presumably its accuracy; confirm against
            // BasicClassificationPerformanceEvaluator) is used as the vote weight.
            double memberWeight = member.evaluator.getPerformanceMeasurements()[1].getValue();
            if (!this.disableWeightedVote.isSet() && memberWeight > 0.0d) {
                for (int c = 0; c < vote.numValues(); c++) {
                    vote.setValue(c, vote.getValue(c) * memberWeight);
                }
            }
            combined.addValues(vote);
        }
        return combined.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /** No model-specific measurements are reported. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Builds the ensemble using {@code instance}'s header.
     *
     * <p>For subspace-based methods it first resolves the subspace size from the options and
     * generates the candidate subspaces.
     *
     * <p>If the resolved subspace is empty, negative or covers every feature, the method switches
     * {@link #trainingMethodOption} to Resampling (plain bagging).
     */
    protected void initEnsemble(Instance instance) {
        int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new StreamingRandomPatchesClassifier[ensembleSize];
        BasicClassificationPerformanceEvaluator evaluatorTemplate = new BasicClassificationPerformanceEvaluator();
        int subspaceSize = this.subspaceSizeOption.getValue();

        if (this.trainingMethodOption.getChosenIndex() != TRAIN_RESAMPLING) {
            int numFeatures = instance.numAttributes() - 1;
            switch (this.subspaceModeOption.getChosenIndex()) {
                case FEATURES_SQRT:
                    subspaceSize = ((int) Math.round(Math.sqrt(numFeatures))) + 1;
                    break;
                case FEATURES_SQRT_INV:
                    subspaceSize = numFeatures - ((int) Math.round(Math.sqrt(numFeatures) + 1.0d));
                    break;
                case FEATURES_PERCENT: {
                    double fraction = subspaceSize < 0 ? (100 + subspaceSize) / 100.0d : subspaceSize / 100.0d;
                    long scaled = Math.round(numFeatures * fraction);
                    subspaceSize = scaled < 2 ? ((int) scaled) + 1 : (int) scaled;
                    break;
                }
                default:
                    // FEATURES_M: the option value is used as given.
                    break;
            }
            if (subspaceSize < 0) {
                subspaceSize = numFeatures + subspaceSize;
            }
            if (this.trainingMethodOption.getChosenIndex() == TRAIN_RANDOM_SUBSPACES
                    || this.trainingMethodOption.getChosenIndex() == TRAIN_RANDOM_PATCHES) {
                if (subspaceSize <= 0 || subspaceSize >= numFeatures) {
                    // No usable proper subspace (a negative size would also break the combination
                    // generators), so fall back to bagging on all features.
                    this.trainingMethodOption.setChosenIndex(TRAIN_RESAMPLING);
                } else if (numFeatures <= 20 || subspaceSize < 2) {
                    if (subspaceSize == 1 && instance.numAttributes() > 2) {
                        subspaceSize = 2;
                    }
                    this.subspaces = allKCombinations(subspaceSize, numFeatures);
                    // Too few distinct combinations: replicate them in order until each member gets one.
                    for (int next = 0; this.subspaces.size() < this.ensemble.length; next++) {
                        this.subspaces.add(new ArrayList<>(this.subspaces.get(next)));
                    }
                } else {
                    this.subspaces = localRandomKCombinations(subspaceSize, numFeatures, ensembleSize, this.classifierRandom);
                }
            }
        }

        Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
        baseLearner.resetLearning();
        for (int m = 0; m < ensembleSize; m++) {
            switch (this.trainingMethodOption.getChosenIndex()) {
                case TRAIN_RANDOM_SUBSPACES:
                case TRAIN_RANDOM_PATCHES: {
                    // Each subspace is handed out at most once.
                    ArrayList<Integer> subspace = this.subspaces.remove(this.classifierRandom.nextInt(this.subspaces.size()));
                    subspace.add(instance.classIndex());
                    this.ensemble[m] = new StreamingRandomPatchesClassifier(m, baseLearner.copy(), (BasicClassificationPerformanceEvaluator) evaluatorTemplate.copy(), this.instancesSeen, this.disableBackgroundLearnerOption.isSet(), this.disableDriftDetectionOption.isSet(), this.driftDetectionMethodOption, this.warningDetectionMethodOption, subspace, instance, false);
                    break;
                }
                case TRAIN_RESAMPLING:
                    this.ensemble[m] = new StreamingRandomPatchesClassifier(m, baseLearner.copy(), (BasicClassificationPerformanceEvaluator) evaluatorTemplate.copy(), this.instancesSeen, this.disableBackgroundLearnerOption.isSet(), this.disableDriftDetectionOption.isSet(), this.driftDetectionMethodOption, this.warningDetectionMethodOption, false);
                    break;
                default:
                    break;
            }
        }
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        return getClass() == StreamingRandomPatches.class ? new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE) : new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }

    /** Returns the members' current models, or an empty array before the ensemble is built. */
    @Override
    public Classifier[] getSublearners() {
        if (this.ensemble == null) {
            return new Classifier[0];
        }
        Classifier[] models = new Classifier[this.ensemble.length];
        for (int m = 0; m < models.length; m++) {
            models[m] = this.ensemble[m].classifier;
        }
        return models;
    }

    /**
     * Draws {@code count} random k-subsets of {0, ..., n-1}.
     *
     * <p>Each subset is drawn independently, so duplicates are possible. Each one is produced by
     * removing {@code n - k} random elements.
     */
    private static ArrayList<ArrayList<Integer>> localRandomKCombinations(int k, int n, int count, Random random) {
        ArrayList<ArrayList<Integer>> combinations = new ArrayList<>();
        for (int c = 0; c < count; c++) {
            ArrayList<Integer> combination = new ArrayList<>(n);
            for (int f = 0; f < n; f++) {
                combination.add(f);
            }
            for (int removed = 0; removed < n - k; removed++) {
                combination.remove(random.nextInt(combination.size()));
            }
            combinations.add(combination);
        }
        return combinations;
    }

    /**
     * Recursively appends to {@code out} every ascending completion of {@code prefix}.
     *
     * <p>Each completion adds {@code remaining} more indices chosen from [start, n).
     */
    private static void allKCombinationsInner(int start, int remaining, ArrayList<Integer> prefix, long n, ArrayList<ArrayList<Integer>> out) {
        if (remaining == 0) {
            out.add(new ArrayList<>(prefix));
            return;
        }
        for (int next = start; next <= n - remaining; next++) {
            prefix.add(next);
            allKCombinationsInner(next + 1, remaining - 1, prefix, n, out);
            prefix.remove(prefix.size() - 1);
        }
    }

    /** Returns all k-subsets of {0, ..., n-1} in lexicographic order. */
    private static ArrayList<ArrayList<Integer>> allKCombinations(int k, int n) {
        ArrayList<ArrayList<Integer>> combinations = new ArrayList<>();
        allKCombinationsInner(0, k, new ArrayList<>(), n, combinations);
        return combinations;
    }
}
