package moa.classifiers.deeplearning;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.driftdetection.ADWIN;
import moa.core.Measurement;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/deeplearning/MLP.class */
/**
 * Multi-layer perceptron classifier for MOA backed by a DJL {@link Mlp} block.
 *
 * <p>The network is built lazily from the first instance passed to {@link
 * #initializeNetwork(Instance)}. It has {@link #numberOfLayers} hidden layers of
 * {@code 2^numberOfNeuronsInEachLayerInLog2} neurons. Training uses a softmax cross-entropy loss
 * over mini-batches of {@link #miniBatchSize} instances. The per-batch loss is fed into an
 * {@link ADWIN} estimator (see {@link #getLossEstimation()}). Inputs can optionally be one-hot
 * encoded (nominal attributes with more than two values) and/or normalized with running
 * statistics.
 *
 * <p>NOTE(review): feature extraction assumes the class attribute is the LAST attribute of the
 * header. {@link #getVotesForFeatureValues(Instance)} drops the last element of
 * {@code toDoubleArray()}, and the encoders read attributes {@code 0 .. numInputAttributes()-1}.
 * Confirm for streams with a different class index.
 */
public class MLP extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;

    // Indices of the entries of optimizerTypeOption.
    public static final int OPTIMIZER_SGD = 0;
    public static final int OPTIMIZER_RMSPROP = 1;
    public static final int OPTIMIZER_RMSPROP_RESET = 2;
    public static final int OPTIMIZER_ADAGRAD = 3;
    public static final int OPTIMIZER_ADAGRAD_RESET = 4;
    public static final int OPTIMIZER_ADAM = 5;
    public static final int OPTIMIZER_ADAM_RESET = 6;

    // Indices of the entries of deviceTypeOption.
    public static final int deviceTypeOptionGPU = 0;
    public static final int deviceTypeOptionCPU = 1;

    /** Receives the loss of every mini-batch passed to {@link #trainOnMiniBatch}. */
    public ADWIN lossEstimator;
    /** Label built from layers, neurons, optimizer and learning rate when the network is created. */
    public String modelName;
    private int numberOfClasses;
    /** Output buffer reused (and returned) by the vote methods. */
    private double[] votes;
    private int gpuCount;
    private static final DecimalFormat decimalFormat = new DecimalFormat("0.00000");
    /**
     * Incremented once per {@link #trainOnMiniBatch} call and used as the sample count for
     * normalization. NOTE(review): this counts mini-batches, while normalization sums are updated
     * once per prediction, so with {@code miniBatchSize > 1} the normalization denominator looks
     * wrong. Confirm before relying on normalization with larger batches.
     */
    protected long samplesSeen = 0;
    /** Number of mini-batches whose loss exceeded the threshold and were back-propagated. */
    protected long trainedCount = 0;
    /** Running sums per encoded feature slot (only allocated when normalization is on). */
    protected NormalizeInfo[] normalizeInfo = null;
    /** Reusable encoded-feature buffer (only allocated when encoding/normalization is on). */
    private double[] pFeatureValues = null;
    /** Single-element label buffer paired with {@link #pFeatureValues}. */
    private double[] pClassValue = null;
    public FloatOption learningRateOption = new FloatOption("learningRate", 'r', "Learning Rate", 0.03d, 1.0E-7d, 1.0d);
    public FloatOption backPropLossThreshold = new FloatOption("backPropLossThreshold", 'b', "Back propagation loss threshold", 0.0d, 0.0d, Math.pow(10.0d, 10.0d));
    public MultiChoiceOption optimizerTypeOption = new MultiChoiceOption("optimizer", 'o', "Choose optimizer", new String[]{"SGD", "RMSPROP", "RMSPROP_RESET", "ADAGRAD", "ADAGRAD_RESET", "ADAM", "ADAM_RESET"}, new String[]{"oSGD", "oRMSPROP", "oRMSPROP_RESET", "oADAGRAD", "oADAGRAD_RESET", "oADAM", "oADAM_RESET"}, 0);
    public FlagOption useOneHotEncode = new FlagOption("useOneHotEncode", 'h', "use one hot encoding");
    public FlagOption useNormalization = new FlagOption("useNormalization", 'n', "Normalize data");
    public IntOption numberOfNeuronsInEachLayerInLog2 = new IntOption("numberOfNeuronsInEachLayerInLog2", 'N', "Number of neurons in the each layer in log2", 10, 0, 20);
    public IntOption numberOfLayers = new IntOption("numberOfLayers", 'L', "Number of layers", 1, 1, 4);
    public IntOption miniBatchSize = new IntOption("miniBatchSize", 'B', "Mini Batch Size", 1, 1, 2048);
    public MultiChoiceOption deviceTypeOption = new MultiChoiceOption("deviceType", 'd', "Choose device to run the model(use CPU if GPUs are not available)", new String[]{"GPU", "CPU"}, new String[]{"GPU (use CPU if not available)", "CPU"}, 1);
    public IntOption djlRandomSeed = new IntOption("djlRandomSeed", 'S', "Random seed for DJL Engine", 10, 0, Integer.MAX_VALUE);
    /** Confidence parameter passed to the {@link ADWIN} loss estimator. */
    public double deltaForADWIN = 1.0E-5d;
    protected Model nnmodel = null;
    protected Trainer trainer = null;
    /** Width of the network input layer (encoded feature count). */
    protected int featureValuesArraySize = 0;
    /** Mini-batch currently being filled; {@code null} until the first training instance. */
    private MiniBatch miniBatch = null;

    /** Running sum and sum of squares of one numeric feature slot, used for normalization. */
    public static class NormalizeInfo {
        double sumOfValues = 0.0d;
        double sumOfSquares = 0.0d;
    }

    @Override
    public String getPurposeString() {
        return "NN: special.";
    }

    /**
     * Resets the learner to its untrained state.
     *
     * <p>Releases the trainer and model, drops any partially filled mini-batch, and clears the
     * counters and encoding/normalization buffers. The next call to {@link
     * #initializeNetwork(Instance)} then rebuilds everything from scratch.
     */
    @Override
    public void resetLearningImpl() {
        if (this.miniBatch != null) {
            // Pending (not yet full) samples are deliberately dropped on reset.
            this.miniBatch.discardMiniBatch();
            this.miniBatch = null;
        }
        if (this.trainer != null) {
            this.trainer.close();
            this.trainer = null;
        }
        if (this.nnmodel != null) {
            this.nnmodel.close();
            this.nnmodel = null;
        }
        this.samplesSeen = 0;
        this.trainedCount = 0;
        this.normalizeInfo = null;
        this.pFeatureValues = null;
        this.pClassValue = null;
        this.featureValuesArraySize = 0;
    }

    /**
     * Runs one forward pass over a mini-batch and optionally back-propagates.
     *
     * <p>The loss is always fed into {@link #lossEstimator}. Back-propagation and an optimizer
     * step happen only when {@code trainNet} is true and the loss exceeds {@link
     * #backPropLossThreshold}. Any unexpected exception terminates the JVM (status 1), matching
     * the original error policy.
     *
     * @param batch the mini-batch whose data ({@code d}) and labels ({@code l}) are used
     * @param trainNet whether a gradient step may be taken
     */
    public void trainOnMiniBatch(MiniBatch batch, boolean trainNet) {
        NDList data = batch.d;
        NDList labels = batch.l;
        try {
            this.samplesSeen++;
            // Resources close in reverse order: loss, predictions, then the gradient collector.
            try (GradientCollector collector = this.trainer.newGradientCollector();
                    NDList predictions = this.trainer.forward(data, labels);
                    NDArray lossArray = this.trainer.getLoss().evaluate(labels, predictions)) {
                double loss = lossArray.getFloat();
                if (trainNet && loss > this.backPropLossThreshold.getValue()) {
                    this.trainedCount++;
                    try {
                        collector.backward(lossArray);
                        this.trainer.step();
                    } catch (IllegalStateException ignored) {
                        // NOTE(review): kept best-effort, as in the original. Presumably DJL can
                        // reject an individual backward/step; the batch is skipped rather than
                        // aborting. Confirm which condition triggers this.
                    }
                }
                this.lossEstimator.setInput(loss);
            }
        } catch (Exception e) {
            exitOnFatal(e);
        }
    }

    /**
     * Adds an instance to the current mini-batch and trains once the batch is full.
     *
     * <p>When one-hot encoding or normalization is enabled, the features come from {@link
     * #pFeatureValues}. That buffer is filled by {@link #getVotesForInstance(Instance)}, so this
     * method assumes a prediction for the same instance was made first (test-then-train).
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        initializeNetwork(instance);
        if (this.miniBatch == null) {
            this.miniBatch = new MiniBatch(this.nnmodel.getNDManager().getDevice(), this.miniBatchSize.getValue());
        }
        if (this.useNormalization.isSet() || this.useOneHotEncode.isSet()) {
            this.pClassValue[0] = instance.classValue();
            this.miniBatch.addToMiniBatch(this.pFeatureValues, this.pClassValue);
        } else {
            this.miniBatch.addToMiniBatch(instance);
        }
        if (this.miniBatch.miniBatchFull()) {
            trainOnMiniBatch(this.miniBatch, true);
            this.miniBatch.discardMiniBatch();
            this.miniBatch = null;
        }
    }

    /**
     * Scores an already-encoded feature vector.
     *
     * @param instance instance used for lazy initialization and the number of classes
     * @param dArr encoded features; its length is expected to equal {@link #featureValuesArraySize}
     * @return the shared {@link #votes} buffer holding the raw network outputs per class
     */
    public double[] getVotesForFeatureValues(Instance instance, double[] dArr) {
        initializeNetwork(instance);
        try (NDManager manager = NDManager.newBaseManager(this.nnmodel.getNDManager().getDevice())) {
            NDArray features = manager.create(dArr).toType(DataType.FLOAT32, false);
            copyNetworkOutputToVotes(features, instance.numClasses());
        } catch (Exception e) {
            exitOnFatal(e);
        }
        return this.votes;
    }

    /**
     * Scores the raw attribute values of an instance, dropping the last value (assumed to be the
     * class, see the class-level note).
     *
     * @param instance the instance to score
     * @return the shared {@link #votes} buffer holding the raw network outputs per class
     */
    public double[] getVotesForFeatureValues(Instance instance) {
        initializeNetwork(instance);
        try (NDManager manager = NDManager.newBaseManager(this.nnmodel.getNDManager().getDevice())) {
            double[] rawValues = instance.toDoubleArray();
            NDArray features = manager.create(rawValues)
                    .toType(DataType.FLOAT32, false)
                    .get("0:" + (rawValues.length - 1));
            copyNetworkOutputToVotes(features, instance.numClasses());
        } catch (Exception e) {
            exitOnFatal(e);
        }
        return this.votes;
    }

    /**
     * Evaluates the network on {@code features} and copies the first {@code numClasses} outputs
     * into {@link #votes}. Output and input lists are closed before the caller closes the manager.
     */
    private void copyNetworkOutputToVotes(NDArray features, int numClasses) {
        try (NDList input = new NDList(features);
                NDList output = this.trainer.evaluate(input)) {
            // Convert once instead of once per class.
            float[] scores = output.get(0).toFloatArray();
            for (int c = 0; c < numClasses; c++) {
                this.votes[c] = scores[c];
            }
        }
    }

    /**
     * Returns class votes for {@code instance}.
     *
     * <p>With encoding or normalization enabled, the instance is first encoded into {@link
     * #pFeatureValues}. This also updates the running normalization sums, a side effect that
     * {@link #trainOnInstanceImpl(Instance)} relies on.
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        initializeNetwork(instance);
        if (!this.useNormalization.isSet() && !this.useOneHotEncode.isSet()) {
            return getVotesForFeatureValues(instance);
        }
        setFeatureValuesArray(instance, this.pFeatureValues, this.useOneHotEncode.isSet(), true, this.normalizeInfo, this.samplesSeen);
        return getVotesForFeatureValues(instance, this.pFeatureValues);
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
    }

    /** This learner reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /** No textual model description is produced. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /** Whether attribute {@code a} is expanded into a one-hot block when encoding is enabled. */
    private static boolean isOneHotCandidate(Instance instance, int a) {
        return instance.attribute(a).isNominal() && instance.attribute(a).numValues() > 2;
    }

    /**
     * Computes the width of the encoded feature vector.
     *
     * @param instance instance providing the attribute header
     * @param oneHotEncode whether nominal attributes with more than two values are one-hot encoded
     * @return {@code numInputAttributes()}, plus {@code numValues - 1} extra slots per one-hot
     *     encoded attribute
     */
    public static int getFeatureValuesArraySize(Instance instance, boolean oneHotEncode) {
        int size = instance.numInputAttributes();
        if (oneHotEncode) {
            for (int a = 0; a < instance.numInputAttributes(); a++) {
                if (isOneHotCandidate(instance, a)) {
                    size += instance.attribute(a).numValues() - 1;
                }
            }
        }
        return size;
    }

    /**
     * Standardizes a value from running sums and divides by three standard deviations.
     *
     * <p>Returns {@code (value - mean) / (3 * sd)} with {@code mean = sum / count} and population
     * variance {@code (sumOfSquares - sum^2 / count) / count}. Returns {@code 0} when
     * {@code count <= 1} or when the standard deviation is not positive. The latter includes
     * {@code NaN} from a slightly negative rounded variance.
     */
    public static double getNormalizedValue(double value, double sum, double sumOfSquares, long count) {
        if (count <= 1) {
            return 0.0d;
        }
        double mean = sum / count;
        double variance = (sumOfSquares - ((sum * sum) / count)) / count;
        double sd = Math.sqrt(variance);
        return sd > 0.0d ? (value - mean) / (3.0d * sd) : 0.0d;
    }

    /**
     * Encodes the input attributes of {@code instance} into {@code dArr}.
     *
     * <p>Nominal attributes with more than two values become a one-hot block when {@code
     * oneHotEncode} is set. The whole block is cleared first, because callers reuse the same
     * buffer across instances. Numeric attributes with a {@link NormalizeInfo} slot are
     * normalized via {@link #getNormalizedValue}. All other attributes are copied as-is.
     *
     * @param instance the instance to encode
     * @param dArr output buffer of length {@link #getFeatureValuesArraySize}
     * @param oneHotEncode whether to one-hot encode multi-valued nominal attributes
     * @param updateStats whether to add this instance's numeric values to the running sums first
     * @param normalizeInfoArr per-slot running sums, or {@code null} to disable normalization
     * @param count sample count used as the normalization denominator
     */
    public static void setFeatureValuesArray(Instance instance, double[] dArr, boolean oneHotEncode, boolean updateStats, NormalizeInfo[] normalizeInfoArr, long count) {
        int extraSlots = 0; // slots added so far by one-hot blocks beyond one per attribute
        for (int a = 0; a < instance.numInputAttributes(); a++) {
            int pos = a + extraSlots;
            double value = instance.value(a);
            if (oneHotEncode && isOneHotCandidate(instance, a)) {
                int width = instance.attribute(a).numValues();
                // Clear stale bits left in the reused buffer by a previous instance.
                Arrays.fill(dArr, pos, pos + width, 0.0d);
                dArr[pos + ((int) value)] = 1.0d;
                extraSlots += width - 1;
            } else if (!instance.attribute(a).isNumeric() || normalizeInfoArr == null || normalizeInfoArr[pos] == null) {
                dArr[pos] = value;
            } else {
                NormalizeInfo info = normalizeInfoArr[pos];
                if (updateStats) {
                    info.sumOfSquares += value * value;
                    info.sumOfValues += value;
                }
                dArr[pos] = getNormalizedValue(value, info.sumOfValues, info.sumOfSquares, count);
            }
        }
    }

    /**
     * Lazily builds the network, buffers, loss estimator and trainer from the header of {@code
     * instance}. Does nothing once a model exists. Initialization errors are printed, not thrown.
     */
    public void initializeNetwork(Instance instance) {
        if (this.nnmodel != null) {
            return;
        }
        for (String engineName : Engine.getAllEngines()) {
            Engine.getEngine(engineName).setRandomSeed(this.djlRandomSeed.getValue());
        }
        this.votes = new double[instance.numClasses()];
        if (this.useNormalization.isSet() || this.useOneHotEncode.isSet()) {
            this.pClassValue = new double[1];
            this.featureValuesArraySize = getFeatureValuesArraySize(instance, this.useOneHotEncode.isSet());
            this.pFeatureValues = new double[this.featureValuesArraySize];
            if (this.useNormalization.isSet()) {
                this.normalizeInfo = new NormalizeInfo[this.featureValuesArraySize];
                for (int i = 0; i < this.normalizeInfo.length; i++) {
                    this.normalizeInfo[i] = new NormalizeInfo();
                }
            }
        } else {
            this.featureValuesArraySize = instance.numInputAttributes();
        }
        try {
            this.gpuCount = Device.getGpuCount();
            this.numberOfClasses = instance.numClasses();
            setModel();
            this.lossEstimator = new ADWIN(this.deltaForADWIN);
            setTrainer();
        } catch (Exception e) {
            printError(e);
        }
        this.modelName = "L" + this.numberOfLayers.getValue() + "_N" + this.numberOfNeuronsInEachLayerInLog2.getValue() + "_" + this.optimizerTypeOption.getChosenLabel() + "_" + decimalFormat.format(this.learningRateOption.getValue());
    }

    /** Returns the current {@link ADWIN} estimate of the mini-batch loss. */
    public double getLossEstimation() {
        return this.lossEstimator.getEstimation();
    }

    /**
     * Creates {@link #nnmodel} on the selected device with an {@link Mlp} block of {@link
     * #numberOfLayers} hidden layers. If GPU is selected but none is present, falls back to the
     * CPU, as the option description states.
     */
    protected void setModel() {
        try {
            boolean gpuRequested = this.deviceTypeOption.getChosenIndex() == deviceTypeOptionGPU;
            if (gpuRequested && this.gpuCount == 0) {
                System.err.println("GPU selected as device. But NO GPUs detected. Using CPU instead.");
            }
            Device device = (gpuRequested && this.gpuCount > 0) ? Device.gpu() : Device.cpu();
            this.nnmodel = Model.newInstance("mlp", device);
            int[] hiddenLayers = new int[this.numberOfLayers.getValue()];
            Arrays.fill(hiddenLayers, 1 << this.numberOfNeuronsInEachLayerInLog2.getValue());
            this.nnmodel.setBlock(new Mlp(this.featureValuesArraySize, this.numberOfClasses, hiddenLayers));
            System.out.println("System GPU count: " + this.gpuCount + " Model using Device: " + this.nnmodel.getNDManager().getDevice());
        } catch (Exception e) {
            printError(e);
        }
    }

    /**
     * (Re)creates {@link #trainer} with the selected optimizer, a fixed learning rate and a softmax
     * cross-entropy loss. Any existing trainer is closed first.
     *
     * <p>NOTE(review): the {@code *_RESET} choices build the same optimizer as their base variant
     * here. Presumably the reset behaviour comes from a caller re-invoking this method. Confirm.
     */
    protected void setTrainer() {
        if (this.trainer != null) {
            this.trainer.close();
            this.trainer = null;
        }
        try {
            Tracker learningRate = Tracker.fixed((float) this.learningRateOption.getValue());
            Optimizer optimizer;
            switch (this.optimizerTypeOption.getChosenIndex()) {
                case OPTIMIZER_RMSPROP:
                case OPTIMIZER_RMSPROP_RESET:
                    optimizer = Optimizer.rmsprop().optLearningRateTracker(learningRate).build();
                    break;
                case OPTIMIZER_ADAGRAD:
                case OPTIMIZER_ADAGRAD_RESET:
                    optimizer = Optimizer.adagrad().optLearningRateTracker(learningRate).build();
                    break;
                case OPTIMIZER_ADAM:
                case OPTIMIZER_ADAM_RESET:
                    optimizer = Optimizer.adam().optLearningRateTracker(learningRate).build();
                    break;
                case OPTIMIZER_SGD:
                default:
                    optimizer = Optimizer.sgd().setLearningRateTracker(learningRate).build();
                    break;
            }
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss());
            config.optOptimizer(optimizer);
            this.trainer = this.nnmodel.newTrainer(config);
            this.trainer.initialize(new Shape(this.miniBatchSize.getValue(), this.featureValuesArraySize));
        } catch (Exception e) {
            printError(e);
        }
    }

    /** Prints an initialization error without aborting (the original best-effort policy). */
    private static void printError(Exception e) {
        System.err.println(e);
        e.printStackTrace();
    }

    /** Prints a fatal error and terminates the JVM with status 1 (the original policy). */
    private static void exitOnFatal(Exception e) {
        printError(e);
        System.exit(1);
    }
}
