package moa.classifiers.deeplearning;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.github.javacliparser.StringOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import moa.capabilities.CapabilitiesHandler;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.classifiers.deeplearning.MLP;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.evaluation.BasicClassificationPerformanceEvaluator;

/**
 * Classifier that keeps a pool of {@link MLP} networks, each built from one point of a grid of
 * learning rates, optimizer choices and first-layer widths, and predicts with the member whose
 * loss estimation ({@code MLP#getLossEstimation()}) is currently the lowest.
 *
 * <p>Training happens in mini-batches of {@code miniBatchSize} instances. When a mini-batch is
 * full, the pool is sorted by loss estimation and every member is handed the batch together with
 * a {@code trainNet} flag. The flag is {@code true} for the best {@code numberOfMLPsToTrain / 2}
 * members plus a rotating selection of the remaining ones. During the first
 * {@code numberOfInstancesToTrainAllMLPsAtStart} instances it is {@code true} for every member.
 * Members are trained on a thread pool unless {@code doNotTrainEachMLPUsingASeparateThread} is set.
 *
 * <p>NOTE(review): feature values are filled in, and the sample counter is incremented, only in
 * {@link #getVotesForInstance(Instance)}. This learner therefore assumes test-then-train usage,
 * i.e. every instance is predicted before it is trained on.
 */
public class CAND extends AbstractClassifier implements MultiClassClassifier, CapabilitiesHandler {
    private static final long serialVersionUID = 1;
    public static final int LARGER_P_POOL_10 = 0;
    public static final int LARGER_P_POOL_30 = 1;

    /** Delta of the ADWIN that watches the running accuracy (used only for drift statistics). */
    private static final double ACCURACY_ADWIN_DELTA = 0.001d;

    /**
     * Index into {@code BasicClassificationPerformanceEvaluator#getPerformanceMeasurements()} whose
     * value is reported under the "acc" column of the stats dump. Presumably this is the percentage
     * of correct classifications; confirm against the evaluator.
     */
    private static final int ACCURACY_MEASUREMENT_INDEX = 1;

    /** Learning rates in the pool are {@code LEARNING_RATE_NUMERATOR / d} for each denominator. */
    private static final float LEARNING_RATE_NUMERATOR = 5.0f;
    private static final float[] LEARNING_RATE_DENOMINATORS = {10.0f, 100.0f, 1000.0f, 10000.0f, 100000.0f};

    /** First-layer widths in the pool are 2^k neurons for k in [MIN, MAX). */
    private static final int MIN_NEURONS_IN_L1_LOG2 = 8;
    private static final int MAX_NEURONS_IN_L1_LOG2_EXCLUSIVE = 11;

    /** {@code MLP.optimizerTypeOption} choice index used for every grid point. */
    private static final int OPTIMIZER_CHOICE_ALL_POOLS = 5;
    /** {@code MLP.optimizerTypeOption} choice index additionally used when the P30 pool is chosen. */
    private static final int OPTIMIZER_CHOICE_P30_ONLY = 0;

    /** ADWIN delta handed to every pool member ({@code MLP#deltaForADWIN}). */
    private static final double MLP_ADWIN_DELTA = 0.001d;

    protected MLP[] nn = null;
    protected int featureValuesArraySize = 0;
    protected long samplesSeen = 0;
    protected MLP.NormalizeInfo[] normalizeInfo = null;
    private double[] featureValues = null;
    private double[] class_value = null;
    private ExecutorService exService = null;
    private FileWriter statsDumpFile = null;
    private FileWriter votesDumpFile = null;
    private BasicClassificationPerformanceEvaluator performanceEvaluator = new BasicClassificationPerformanceEvaluator();
    private long driftsDetectedPerSampleFrequency = 0;
    private long totalDriftsDetected = 0;
    private long avgMLPsPerSampleFrequency = 0;
    private long lastGetModelMeasurementsImplCalledAt = 0;
    private MiniBatch miniBatch = null;
    private ADWIN accEstimator = new ADWIN(ACCURACY_ADWIN_DELTA);
    // NOTE(review): these descriptions do not match the pools built by buildPoolConfigs().
    // Index 0 builds 15 MLPs (choice 5 at widths 2^8..2^10). Index 1 builds 30 MLPs
    // (choices 0 and 5 at widths 2^8..2^10). The two texts look swapped. Confirm the intended
    // pools against the upstream source.
    public MultiChoiceOption largerPool = new MultiChoiceOption("largerPool", 'P', "The larger pool type", new String[]{"P10", "P30"}, new String[]{"P10 = { learning rates: 5.0E-(1 to 5), optimizes: SGD,Adam, neurons in 1st layer:  2^(8 to 10) }", "P30 = { learning rates: 5.0E-(1 to 5), optimizes: Adam, neurons in 1st layer:  2^9 }"}, 1);
    public IntOption numberOfMLPsToTrainOption = new IntOption("numberOfMLPsToTrain", 'o', "Number of MLPs to train at a given time (after numberOfInstancesToTrainAllMLPsAtStart instances)", 10, 2, Integer.MAX_VALUE);
    public IntOption numberOfLayersInEachMLP = new IntOption("numberOfLayersInEachMLP", 'L', "Number of layers in each MLP", 1, 1, 4);
    public IntOption numberOfInstancesToTrainAllMLPsAtStartOption = new IntOption("numberOfInstancesToTrainAllMLPsAtStart", 's', "Number of instances to train all MLPs at start", 100, 0, Integer.MAX_VALUE);
    public IntOption miniBatchSize = new IntOption("miniBatchSize", 'B', "Mini Batch Size", 1, 1, 2048);
    public FlagOption useOneHotEncode = new FlagOption("useOneHotEncode", 'h', "use one hot encoding");
    public FlagOption useNormalization = new FlagOption("useNormalization", 'n', "Normalize data");
    public FloatOption backPropLossThreshold = new FloatOption("backPropLossThreshold", 'b', "Skip back propagation loss threshold", 0.3d, 0.0d, Math.pow(10.0d, 10.0d));
    public MultiChoiceOption deviceTypeOption = new MultiChoiceOption("deviceType", 'd', "Choose device to run the model(For GPU, needs CUDA installed on the system. Use CPU if GPUs are not available)", new String[]{"GPU", "CPU"}, new String[]{"GPU (Needs CUDA installed on the system. Use CPU if not available)", "CPU"}, 1);
    public FlagOption doNotTrainEachMLPUsingASeparateThread = new FlagOption("doNotTrainEachMLPUsingASeparateThread", 't', "Do NOT train each MLP using a separate thread");
    public StringOption votesDumpFileName = new StringOption("votesDumpFileName", 'f', "Votes dump file name", "");
    public StringOption statsDumpFileName = new StringOption("statsDumpFileName", 'F', "Stats dump file name", "");
    public IntOption djlRandomSeed = new IntOption("djlRandomSeed", 'S', "Random seed for DJL Engine", 10, 0, Integer.MAX_VALUE);

    /** One point of the hyper-parameter grid from which a pool member is built. */
    private static final class MLPConfig {
        private final int numberOfNeuronsInL1InLog2;
        private final int optimizerType;
        private final float learningRate;
        private final double deltaForADWIN;

        MLPConfig(int numberOfNeuronsInL1InLog2, int optimizerType, float learningRate, double deltaForADWIN) {
            this.numberOfNeuronsInL1InLog2 = numberOfNeuronsInL1InLog2;
            this.optimizerType = optimizerType;
            this.learningRate = learningRate;
            this.deltaForADWIN = deltaForADWIN;
        }
    }

    /** Worker that hands one mini-batch to one pool member. */
    private static final class TrainTask implements Callable<Boolean> {
        private final MLP mlp;
        private final MiniBatch miniBatch;
        private final boolean trainNet;

        TrainTask(MLP mlp, MiniBatch miniBatch, boolean trainNet) {
            this.mlp = mlp;
            this.miniBatch = miniBatch;
            this.trainNet = trainNet;
        }

        @Override // java.util.concurrent.Callable
        public Boolean call() {
            try {
                this.mlp.trainOnMiniBatch(this.miniBatch, this.trainNet);
            } catch (NullPointerException e) {
                // Preserved behavior: a NullPointerException in any worker terminates the JVM.
                e.printStackTrace();
                System.exit(1);
            }
            return Boolean.TRUE;
        }
    }

    /**
     * Discards the pool, the worker threads, any partially filled mini-batch, the dump files and
     * all running statistics. The next instance then re-initialises everything from scratch.
     */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        if (this.exService != null) {
            this.exService.shutdownNow();
            this.exService = null;
        }
        // Without this, samples buffered before the reset would be trained into the new pool.
        if (this.miniBatch != null) {
            this.miniBatch.discardMiniBatch();
            this.miniBatch = null;
        }
        if (this.nn != null) {
            Arrays.fill(this.nn, null);
            this.nn = null;
        }
        // Close and forget the writers. Otherwise they leak, and printStats() would touch nn
        // after it was cleared.
        closeDumpFile(this.statsDumpFile);
        this.statsDumpFile = null;
        closeDumpFile(this.votesDumpFile);
        this.votesDumpFile = null;
        this.featureValuesArraySize = 0;
        this.samplesSeen = 0L;
        this.normalizeInfo = null;
        this.featureValues = null;
        this.class_value = null;
        this.performanceEvaluator = new BasicClassificationPerformanceEvaluator();
        this.accEstimator = new ADWIN(ACCURACY_ADWIN_DELTA);
        this.driftsDetectedPerSampleFrequency = 0L;
        this.totalDriftsDetected = 0L;
        this.avgMLPsPerSampleFrequency = 0L;
        this.lastGetModelMeasurementsImplCalledAt = 0L;
    }

    /**
     * Adds the current instance to the mini-batch. When the batch is full, every pool member is
     * handed the batch; the selected subset gets {@code trainNet == true}.
     *
     * <p>NOTE(review): the features added here are those computed by the last call to
     * {@link #getVotesForInstance(Instance)}, which is assumed to have seen the same instance.
     *
     * @param instance the labelled instance to learn from
     */
    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        if (this.nn == null) {
            initNNs(instance);
        }
        this.class_value[0] = instance.classValue();
        if (this.miniBatch == null) {
            this.miniBatch = new MiniBatch(this.nn[0].nnmodel.getNDManager().getDevice(), this.miniBatchSize.getValue());
        }
        this.miniBatch.addToMiniBatch(this.featureValues, this.class_value);
        if (!this.miniBatch.miniBatchFull()) {
            return;
        }

        int numToTrain;
        int numBest;
        if (this.samplesSeen < this.numberOfInstancesToTrainAllMLPsAtStartOption.getValue()) {
            numToTrain = this.nn.length;
            numBest = this.nn.length;
        } else {
            // Clamp to the pool size. Larger values used to index past the mask and divide by zero.
            numToTrain = Math.min(this.numberOfMLPsToTrainOption.getValue(), this.nn.length);
            numBest = numToTrain / 2;
        }
        this.avgMLPsPerSampleFrequency += numToTrain;

        // Best (lowest loss) members first, so the first numBest slots are the current leaders.
        Arrays.sort(this.nn, Comparator.comparingDouble(MLP::getLossEstimation));
        boolean[] trainNet = selectMLPsToTrain(this.nn.length, numToTrain, numBest, this.samplesSeen);

        if (this.doNotTrainEachMLPUsingASeparateThread.isSet()) {
            for (int k = 0; k < this.nn.length; k++) {
                this.nn[k].initializeNetwork(instance);
                this.nn[k].trainOnMiniBatch(this.miniBatch, trainNet[k]);
            }
        } else {
            List<Future<Boolean>> futures = new ArrayList<>(this.nn.length);
            for (int k = 0; k < this.nn.length; k++) {
                futures.add(this.exService.submit(new TrainTask(this.nn[k], this.miniBatch, trainNet[k])));
            }
            awaitAll(futures);
        }
        this.miniBatch.discardMiniBatch();
        this.miniBatch = null;
    }

    /**
     * Builds the per-member training mask. The first {@code numBest} slots are always selected.
     * The remaining {@code numToTrain - numBest} selections are consecutive slots in the tail
     * {@code [numBest, poolSize)}, starting at an offset that shifts with {@code rotation}.
     *
     * <p>Requires {@code numBest <= numToTrain <= poolSize}, and {@code numBest < poolSize}
     * whenever {@code numToTrain > numBest}.
     */
    private static boolean[] selectMLPsToTrain(int poolSize, int numToTrain, int numBest, long rotation) {
        boolean[] train = new boolean[poolSize];
        for (int k = 0; k < numBest; k++) {
            train[k] = true;
        }
        int tailSize = poolSize - numBest;
        for (int k = numBest; k < numToTrain; k++) {
            train[numBest + (int) ((rotation + k) % tailSize)] = true;
        }
        return train;
    }

    /**
     * Waits for every task, even if the calling thread is interrupted meanwhile. This keeps the
     * shared mini-batch from being discarded while a worker may still use it. A pending interrupt
     * is restored before returning; task failures are printed and otherwise ignored.
     */
    private static void awaitAll(List<Future<Boolean>> futures) {
        boolean interrupted = false;
        for (Future<Boolean> future : futures) {
            while (true) {
                try {
                    future.get();
                    break;
                } catch (InterruptedException e) {
                    interrupted = true;
                } catch (ExecutionException e) {
                    e.printStackTrace();
                    break;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Appends one CSV row per pool member to the stats dump, if one is open. Rows follow the
     * header written in {@code initNNs}.
     */
    private void printStats() {
        if (this.statsDumpFile == null || this.nn == null) {
            return;
        }
        long sampleFrequency = this.samplesSeen - this.lastGetModelMeasurementsImplCalledAt;
        // Two reports with no new predictions in between would otherwise divide by zero.
        long avgMLPs = sampleFrequency == 0L ? 0L : this.avgMLPsPerSampleFrequency / sampleFrequency;
        double accuracy = this.performanceEvaluator.getPerformanceMeasurements()[ACCURACY_MEASUREMENT_INDEX].getValue();
        for (MLP mlp : this.nn) {
            try {
                this.statsDumpFile.write(this.samplesSeen + "," + mlp.samplesSeen + "," + mlp.trainedCount + "," + mlp.modelName + "," + accuracy + "," + mlp.lossEstimator.getEstimation() + "," + this.totalDriftsDetected + "," + sampleFrequency + "," + this.driftsDetectedPerSampleFrequency + "," + avgMLPs + "\n");
                this.statsDumpFile.flush();
            } catch (IOException e) {
                System.out.println("An error occurred.");
                e.printStackTrace();
            }
        }
    }

    /**
     * Appends one CSV row per pool member with that member's votes for {@code instance}, if a
     * votes dump is open. The file is flushed in {@link #getModelMeasurementsImpl()}.
     */
    private void printVotes(Instance instance) {
        if (this.votesDumpFile == null || this.nn == null) {
            return;
        }
        for (MLP mlp : this.nn) {
            try {
                this.votesDumpFile.write(this.samplesSeen + "," + mlp.modelName + "," + mlp.lossEstimator.getEstimation() + "," + instance.classValue() + "," + instance.classIndex() + "," + Arrays.toString(mlp.getVotesForFeatureValues(instance, this.featureValues)) + "\n");
            } catch (IOException e) {
                System.out.println("An error occurred.");
                e.printStackTrace();
            }
        }
    }

    /**
     * Predicts with the pool member that currently has the lowest loss estimation. The pool is
     * created on the first call, and member 0 predicts then.
     *
     * <p>Side effects:
     * <ul>
     *   <li>increments {@code samplesSeen};</li>
     *   <li>fills {@code featureValues}, which later feed training;</li>
     *   <li>updates the accuracy evaluator and its ADWIN drift counters;</li>
     *   <li>writes to the votes dump.</li>
     * </ul>
     *
     * @param instance the instance to classify (its class value is read for evaluation)
     * @return the votes of the selected member
     */
    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        this.samplesSeen++;
        int best;
        if (this.nn == null) {
            initNNs(instance);
            best = 0;
        } else {
            best = indexOfLowestLoss();
        }
        MLP.setFeatureValuesArray(instance, this.featureValues, this.useOneHotEncode.isSet(), true, this.normalizeInfo, this.samplesSeen);
        double[] votes = this.nn[best].getVotesForFeatureValues(instance, this.featureValues);
        this.performanceEvaluator.addResult(new InstanceExample(instance), votes);

        // Count a drift only when ADWIN flags a change and the estimated accuracy went down.
        double previousEstimation = this.accEstimator.getEstimation();
        this.accEstimator.setInput(this.performanceEvaluator.getPerformanceMeasurements()[ACCURACY_MEASUREMENT_INDEX].getValue());
        if (this.accEstimator.getChange() && this.accEstimator.getEstimation() < previousEstimation) {
            this.totalDriftsDetected++;
            this.driftsDetectedPerSampleFrequency++;
        }
        printVotes(instance);
        return votes;
    }

    /** Returns the index of the first pool member with the strictly lowest loss estimation (0 if none is below {@code Double.MAX_VALUE}). */
    private int indexOfLowestLoss() {
        int best = 0;
        double lowest = Double.MAX_VALUE;
        for (int k = 0; k < this.nn.length; k++) {
            double loss = this.nn[k].getLossEstimation();
            if (loss < lowest) {
                lowest = loss;
                best = k;
            }
        }
        return best;
    }

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return false;
    }

    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /**
     * Writes the stats dump for the interval since the previous call, then starts a new interval
     * and flushes the votes dump. No model-specific measurements are reported.
     *
     * @return an empty array (never {@code null})
     */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        printStats();
        this.driftsDetectedPerSampleFrequency = 0L;
        this.avgMLPsPerSampleFrequency = 0L;
        this.lastGetModelMeasurementsImplCalledAt = this.samplesSeen;
        if (this.votesDumpFile != null) {
            try {
                this.votesDumpFile.flush();
            } catch (IOException e) {
                System.out.println("An error occurred.");
                e.printStackTrace();
            }
        }
        return new Measurement[0];
    }

    /**
     * Builds the pool from the hyper-parameter grid, then:
     * <ul>
     *   <li>opens the optional dump files;</li>
     *   <li>starts one worker thread per member;</li>
     *   <li>sizes the feature buffers from {@code instance}'s header.</li>
     * </ul>
     *
     * @param instance an instance whose header determines the input layout
     */
    protected void initNNs(Instance instance) {
        List<MLPConfig> configs = buildPoolConfigs();
        this.nn = new MLP[configs.size()];
        for (int k = 0; k < this.nn.length; k++) {
            this.nn[k] = createMLP(configs.get(k), instance);
        }
        openDumpFiles();
        this.exService = Executors.newFixedThreadPool(this.nn.length);
        this.class_value = new double[1];
        this.featureValuesArraySize = MLP.getFeatureValuesArraySize(instance, this.useOneHotEncode.isSet());
        System.out.println("Number of features before one-hot encode: " + instance.numInputAttributes() + " : Number of features after one-hot encode: " + this.featureValuesArraySize);
        this.featureValues = new double[this.featureValuesArraySize];
        if (this.useNormalization.isSet()) {
            this.normalizeInfo = new MLP.NormalizeInfo[this.featureValuesArraySize];
            for (int k = 0; k < this.normalizeInfo.length; k++) {
                this.normalizeInfo[k] = new MLP.NormalizeInfo();
            }
        }
    }

    /**
     * Enumerates the pool's grid points: learning rate x first-layer width, with optimizer choice
     * {@link #OPTIMIZER_CHOICE_P30_ONLY} (P30 only) followed by {@link #OPTIMIZER_CHOICE_ALL_POOLS}.
     *
     * <p>NOTE(review): the decompiled original had an empty {@code else if (i == 9) {}} branch
     * here, which is a no-op and has been dropped. It may be a lost {@code continue} that
     * restricted the non-P30 pool to 2^9 neurons. Confirm against the upstream source.
     */
    private List<MLPConfig> buildPoolConfigs() {
        boolean p30 = this.largerPool.getChosenIndex() == LARGER_P_POOL_30;
        List<MLPConfig> configs = new ArrayList<>();
        for (float denominator : LEARNING_RATE_DENOMINATORS) {
            float learningRate = LEARNING_RATE_NUMERATOR / denominator;
            for (int log2 = MIN_NEURONS_IN_L1_LOG2; log2 < MAX_NEURONS_IN_L1_LOG2_EXCLUSIVE; log2++) {
                if (p30) {
                    configs.add(new MLPConfig(log2, OPTIMIZER_CHOICE_P30_ONLY, learningRate, MLP_ADWIN_DELTA));
                }
                configs.add(new MLPConfig(log2, OPTIMIZER_CHOICE_ALL_POOLS, learningRate, MLP_ADWIN_DELTA));
            }
        }
        return configs;
    }

    /** Creates and initialises one pool member from a grid point plus this learner's shared options. */
    private MLP createMLP(MLPConfig config, Instance instance) {
        MLP mlp = new MLP();
        mlp.optimizerTypeOption.setChosenIndex(config.optimizerType);
        mlp.learningRateOption.setValue(config.learningRate);
        mlp.useOneHotEncode.setValue(this.useOneHotEncode.isSet());
        mlp.deviceTypeOption.setChosenIndex(this.deviceTypeOption.getChosenIndex());
        mlp.numberOfNeuronsInEachLayerInLog2.setValue(config.numberOfNeuronsInL1InLog2);
        mlp.numberOfLayers.setValue(this.numberOfLayersInEachMLP.getValue());
        mlp.deltaForADWIN = config.deltaForADWIN;
        mlp.backPropLossThreshold.setValue(this.backPropLossThreshold.getValue());
        mlp.djlRandomSeed.setValue(this.djlRandomSeed.getValue());
        mlp.initializeNetwork(instance);
        return mlp;
    }

    /**
     * Opens (truncating) the stats and votes dump files named by the options, if any, and writes
     * their CSV headers. An I/O error is printed and stops opening any further file.
     */
    private void openDumpFiles() {
        // Defensive: never leak a writer if the pool is re-initialised without a reset.
        closeDumpFile(this.statsDumpFile);
        this.statsDumpFile = null;
        closeDumpFile(this.votesDumpFile);
        this.votesDumpFile = null;
        try {
            if (this.statsDumpFileName.getValue().length() > 0) {
                this.statsDumpFile = new FileWriter(this.statsDumpFileName.getValue());
                this.statsDumpFile.write("id,samplesSeenAtTrain,trainedCount,optimizer_type_learning_rate_delta,acc,estimated_loss,totalDriftsDetected,sampleFrequency,driftsDetectedPerSampleFrequency,avgMLPsPerSampleFrequency\n");
                this.statsDumpFile.flush();
            }
            if (this.votesDumpFileName.getValue().length() > 0) {
                this.votesDumpFile = new FileWriter(this.votesDumpFileName.getValue());
                this.votesDumpFile.write("id,modelName,estimated_loss,classValue,classIndex,votes,\n");
                this.votesDumpFile.flush();
            }
        } catch (IOException e) {
            System.out.println("An error occurred.");
            e.printStackTrace();
        }
    }

    /** Closes {@code writer} if non-null; an I/O error is printed and otherwise ignored. */
    private static void closeDumpFile(FileWriter writer) {
        if (writer == null) {
            return;
        }
        try {
            writer.close();
        } catch (IOException e) {
            System.out.println("An error occurred.");
            e.printStackTrace();
        }
    }

    @Override // moa.classifiers.AbstractClassifier, moa.capabilities.CapabilitiesHandler
    public ImmutableCapabilities defineImmutableCapabilities() {
        return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
    }
}
