package moa.classifiers.rules.functions;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
import moa.classifiers.AbstractClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/rules/functions/Perceptron.class */
public class Perceptron extends AbstractClassifier implements AMRulesRegressorFunction {
    private final double SD_THRESHOLD = 1.0E-7d;
    private static final long serialVersionUID = 1;
    public FlagOption constantLearningRatioDecayOption;
    public FloatOption learningRatioOption;
    public FloatOption learningRateDecayOption;
    public FloatOption fadingFactorOption;
    public IntOption randomSeedOption;
    private double nError;
    protected double fadingFactor;
    protected double learningRatio;
    protected double learningRateDecay;
    protected double[] weightAttribute;
    public DoubleVector perceptronattributeStatistics;
    public DoubleVector squaredperceptronattributeStatistics;
    protected double perceptronInstancesSeen;
    protected double perceptronYSeen;
    protected double accumulatedError;
    protected boolean initialisePerceptron;
    protected double perceptronsumY;
    protected double squaredperceptronsumY;
    protected int[] numericAttributesIndex;

    public Perceptron() {
        this.SD_THRESHOLD = 1.0E-7d;
        this.constantLearningRatioDecayOption = new FlagOption("learningRatio_Decay_set_constant", 'd', "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
        this.learningRatioOption = new FloatOption("learningRatio", 'l', "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025d);
        this.learningRateDecayOption = new FloatOption("learningRateDecay", 'm', " Learning Rate decay to use for training the Perceptron.", 0.001d);
        this.fadingFactorOption = new FloatOption("fadingFactor", 'e', "Fading factor for the Perceptron accumulated error", 0.99d, 0.0d, 1.0d);
        this.randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the classifier.", 1);
        this.perceptronattributeStatistics = new DoubleVector();
        this.squaredperceptronattributeStatistics = new DoubleVector();
        super.randomSeedOption = this.randomSeedOption;
        this.initialisePerceptron = true;
    }

    public Perceptron(Perceptron perceptron) {
        this.SD_THRESHOLD = 1.0E-7d;
        this.constantLearningRatioDecayOption = new FlagOption("learningRatio_Decay_set_constant", 'd', "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
        this.learningRatioOption = new FloatOption("learningRatio", 'l', "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025d);
        this.learningRateDecayOption = new FloatOption("learningRateDecay", 'm', " Learning Rate decay to use for training the Perceptron.", 0.001d);
        this.fadingFactorOption = new FloatOption("fadingFactor", 'e', "Fading factor for the Perceptron accumulated error", 0.99d, 0.0d, 1.0d);
        this.randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the classifier.", 1);
        this.perceptronattributeStatistics = new DoubleVector();
        this.squaredperceptronattributeStatistics = new DoubleVector();
        this.constantLearningRatioDecayOption = perceptron.constantLearningRatioDecayOption;
        this.learningRatioOption = perceptron.learningRatioOption;
        this.learningRateDecayOption = perceptron.learningRateDecayOption;
        this.fadingFactorOption = perceptron.fadingFactorOption;
        this.nError = perceptron.nError;
        this.fadingFactor = perceptron.fadingFactor;
        this.learningRatio = perceptron.learningRatio;
        this.learningRateDecay = perceptron.learningRateDecay;
        if (perceptron.weightAttribute != null) {
            this.weightAttribute = (double[]) perceptron.weightAttribute.clone();
        }
        this.perceptronattributeStatistics = new DoubleVector(perceptron.perceptronattributeStatistics);
        this.squaredperceptronattributeStatistics = new DoubleVector(perceptron.squaredperceptronattributeStatistics);
        this.perceptronInstancesSeen = perceptron.perceptronInstancesSeen;
        this.initialisePerceptron = perceptron.initialisePerceptron;
        this.perceptronsumY = perceptron.perceptronsumY;
        this.squaredperceptronsumY = perceptron.squaredperceptronsumY;
        this.perceptronYSeen = perceptron.perceptronYSeen;
        this.numericAttributesIndex = (int[]) perceptron.numericAttributesIndex.clone();
        this.randomSeed = perceptron.randomSeed;
    }

    public void setWeights(double[] dArr) {
        this.weightAttribute = dArr;
    }

    public double[] getWeights() {
        return this.weightAttribute;
    }

    public double getInstancesSeen() {
        return this.perceptronInstancesSeen;
    }

    public void setInstancesSeen(int i) {
        this.perceptronInstancesSeen = i;
    }

    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.initialisePerceptron = true;
        reset();
    }

    public void reset() {
        this.classifierRandom.setSeed(this.randomSeed);
        this.nError = 0.0d;
        this.accumulatedError = 0.0d;
        this.perceptronInstancesSeen = 0.0d;
        this.perceptronattributeStatistics = new DoubleVector();
        this.squaredperceptronattributeStatistics = new DoubleVector();
        this.perceptronsumY = 0.0d;
        this.squaredperceptronsumY = 0.0d;
        this.perceptronYSeen = 0.0d;
    }

    public void resetError() {
        this.nError = 0.0d;
        this.accumulatedError = 0.0d;
    }

    @Override // moa.classifiers.AbstractClassifier
    public void trainOnInstanceImpl(Instance instance) {
        this.accumulatedError = (Math.abs(prediction(instance) - instance.classValue()) * instance.weight()) + (this.fadingFactor * this.accumulatedError);
        this.nError = instance.weight() + (this.fadingFactor * this.nError);
        if (this.initialisePerceptron) {
            LinkedList linkedList = new LinkedList();
            for (int i = 0; i < instance.numAttributes(); i++) {
                if (instance.attribute(i).isNumeric() && i != instance.classIndex()) {
                    linkedList.add(Integer.valueOf(i));
                }
            }
            this.numericAttributesIndex = new int[linkedList.size()];
            int i2 = 0;
            Iterator it = linkedList.iterator();
            while (it.hasNext()) {
                int i3 = i2;
                i2++;
                this.numericAttributesIndex[i3] = ((Integer) it.next()).intValue();
            }
            this.fadingFactor = this.fadingFactorOption.getValue();
            this.initialisePerceptron = false;
            this.weightAttribute = new double[this.numericAttributesIndex.length + 1];
            for (int i4 = 0; i4 < this.numericAttributesIndex.length + 1; i4++) {
                this.weightAttribute[i4] = (2.0d * this.classifierRandom.nextDouble()) - 1.0d;
            }
            this.learningRatio = this.learningRatioOption.getValue();
            this.learningRateDecay = this.learningRateDecayOption.getValue();
        }
        this.perceptronInstancesSeen += instance.weight();
        this.perceptronYSeen += instance.weight();
        for (int i5 = 0; i5 < this.numericAttributesIndex.length; i5++) {
            double value = instance.value(modelAttIndexToInstanceAttIndex(this.numericAttributesIndex[i5], instance));
            this.perceptronattributeStatistics.addToValue(i5, value * instance.weight());
            this.squaredperceptronattributeStatistics.addToValue(i5, value * value * instance.weight());
        }
        double classValue = instance.classValue();
        this.perceptronsumY += classValue * instance.weight();
        this.squaredperceptronsumY += classValue * classValue * instance.weight();
        if (!this.constantLearningRatioDecayOption.isSet()) {
            this.learningRatio = this.learningRatioOption.getValue() / (1.0d + (this.perceptronInstancesSeen * this.learningRateDecay));
        }
        updateWeights(instance, this.learningRatio);
    }

    private double prediction(Instance instance) {
        if (this.initialisePerceptron) {
            return 0.0d;
        }
        return denormalizedPrediction(prediction(normalizedInstance(instance)));
    }

    public double normalizedPrediction(Instance instance) {
        return prediction(normalizedInstance(instance));
    }

    private double denormalizedPrediction(double d) {
        if (this.initialisePerceptron) {
            return d;
        }
        double d2 = this.perceptronsumY / this.perceptronYSeen;
        double computeSD = computeSD(this.squaredperceptronsumY, this.perceptronsumY, this.perceptronYSeen);
        return computeSD > 1.0E-7d ? (d * computeSD) + d2 : d + d2;
    }

    public double prediction(double[] dArr) {
        double d = 0.0d;
        if (!this.initialisePerceptron) {
            for (int i = 0; i < dArr.length - 1; i++) {
                d += this.weightAttribute[i] * dArr[i];
            }
            d += this.weightAttribute[dArr.length - 1];
        }
        return d;
    }

    public double[] normalizedInstance(Instance instance) {
        double[] dArr = new double[this.numericAttributesIndex.length + 1];
        for (int i = 0; i < this.numericAttributesIndex.length; i++) {
            int modelAttIndexToInstanceAttIndex = modelAttIndexToInstanceAttIndex(this.numericAttributesIndex[i], instance);
            double value = this.perceptronattributeStatistics.getValue(i) / this.perceptronYSeen;
            double computeSD = computeSD(this.squaredperceptronattributeStatistics.getValue(i), this.perceptronattributeStatistics.getValue(i), this.perceptronYSeen);
            if (computeSD > 1.0E-7d) {
                dArr[i] = (instance.value(modelAttIndexToInstanceAttIndex) - value) / computeSD;
            } else {
                dArr[i] = instance.value(modelAttIndexToInstanceAttIndex) - value;
            }
        }
        return dArr;
    }

    public double computeSD(double d, double d2, double d3) {
        if (d3 > 1.0d) {
            return Math.sqrt((d - ((d2 * d2) / d3)) / (d3 - 1.0d));
        }
        return 0.0d;
    }

    public void updateWeights(Instance instance, double d) {
        double[] normalizedInstance = normalizedInstance(instance);
        double d2 = 0.0d;
        double normalizeActualClassValue = normalizeActualClassValue(instance) - prediction(normalizedInstance);
        for (int i = 0; i < this.numericAttributesIndex.length; i++) {
            double[] dArr = this.weightAttribute;
            int i2 = i;
            dArr[i2] = dArr[i2] + (d * normalizeActualClassValue * normalizedInstance[i] * instance.weight());
            d2 += Math.abs(this.weightAttribute[i]);
        }
        double[] dArr2 = this.weightAttribute;
        int length = this.numericAttributesIndex.length;
        dArr2[length] = dArr2[length] + (d * normalizeActualClassValue * instance.weight());
        double abs = d2 + Math.abs(this.weightAttribute[this.numericAttributesIndex.length]);
        if (abs > this.numericAttributesIndex.length) {
            for (int i3 = 0; i3 < this.numericAttributesIndex.length; i3++) {
                this.weightAttribute[i3] = this.weightAttribute[i3] / abs;
            }
            this.weightAttribute[this.numericAttributesIndex.length] = this.weightAttribute[this.numericAttributesIndex.length] / abs;
        }
    }

    public void normalizeWeights() {
        double d = 0.0d;
        for (int i = 0; i < this.weightAttribute.length; i++) {
            d += Math.abs(this.weightAttribute[i]);
        }
        for (int i2 = 0; i2 < this.weightAttribute.length; i2++) {
            this.weightAttribute[i2] = this.weightAttribute[i2] / d;
        }
    }

    private double normalizeActualClassValue(Instance instance) {
        double d = this.perceptronsumY / this.perceptronYSeen;
        double computeSD = computeSD(this.squaredperceptronsumY, this.perceptronsumY, this.perceptronYSeen);
        return computeSD > 1.0E-7d ? (instance.classValue() - d) / computeSD : instance.classValue() - d;
    }

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return true;
    }

    @Override // moa.classifiers.AbstractClassifier, moa.classifiers.Classifier
    public double[] getVotesForInstance(Instance instance) {
        return !this.initialisePerceptron ? new double[]{prediction(instance)} : new double[]{0.0d};
    }

    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
        if (this.weightAttribute != null) {
            for (int i2 = 0; i2 < this.weightAttribute.length - 1; i2++) {
                if (this.weightAttribute[i2] > 0.0d && i2 > 0) {
                    sb.append(" +" + (Math.round(this.weightAttribute[i2] * 1000.0d) / 1000.0d) + " X" + i2);
                } else if (this.weightAttribute[i2] < 0.0d || i2 == 0) {
                    sb.append(" " + (Math.round(this.weightAttribute[i2] * 1000.0d) / 1000.0d) + " X" + i2);
                }
            }
            if (this.weightAttribute[this.weightAttribute.length - 1] >= 0.0d) {
                sb.append(" +" + (Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000.0d) / 1000.0d));
            } else {
                sb.append(" " + (Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000.0d) / 1000.0d));
            }
        }
    }

    public void setLearningRatio(double d) {
        this.learningRatio = d;
    }

    @Override // moa.classifiers.rules.functions.AMRulesLearner
    public double getCurrentError() {
        if (this.nError > 0.0d) {
            return this.accumulatedError / this.nError;
        }
        return Double.MAX_VALUE;
    }
}
