package moa.classifiers.functions;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.core.DoubleVector;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/functions/AdaGrad.class */
/**
 * Linear learner that reuses the loss handling of {@link SGD} but scales every
 * update by the inverse square root of the running sum of squared gradients
 * (one accumulator per weight plus one for the bias).
 *
 * <p>The step applied to a parameter with gradient {@code g} and accumulator
 * {@code v} is {@code learningRate / (sqrt(v) + epsilon) * g}, where {@code v}
 * already includes {@code g * g}.
 */
public class AdaGrad extends SGD {
    private static final long serialVersionUID = -3732968666673530291L;

    /** Default value of the epsilon term, shared by the field and the CLI option. */
    private static final double DEFAULT_EPSILON = 1.0E-8d;

    /**
     * Loss index for which the logistic gradient {@code sigmoid(z) - y} is used on
     * nominal classes; every other index falls through to the hinge gradient.
     */
    private static final int LOG_LOSS_INDEX = 1;

    /** Small constant added to the denominator of every step size. */
    protected double m_epsilon = DEFAULT_EPSILON;

    /** Option from which {@link #resetLearningImpl()} reads the epsilon value. */
    public FloatOption epsilonOption = new FloatOption("epsilon", 'p', "epsilon parameter.", DEFAULT_EPSILON);

    /** Running sum of squared gradients, one entry per weight. */
    protected DoubleVector m_velocity;

    /** Running sum of squared gradients for the bias term. */
    protected double m_biasVelocity;

    /**
     * Returns the one-line description of this learner.
     *
     * @return the purpose string
     */
    @Override
    public String getPurposeString() {
        return "An online optimiser for learning various linear models (binary class SVM, binary class logistic regression and linear regression).";
    }

    /**
     * Sets the epsilon term used in the step-size denominators.
     *
     * @param d the new epsilon value
     */
    public void setEpsilon(double d) {
        this.m_epsilon = d;
    }

    /**
     * Returns the epsilon term used in the step-size denominators.
     *
     * @return the current epsilon value
     */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /**
     * Creates the learner and replaces the inherited regularization and
     * learning-rate options with copies (same name, CLI character and purpose)
     * whose defaults are {@code 0.0} and {@code 0.01} respectively.
     */
    public AdaGrad() {
        this.lambdaRegularizationOption = new FloatOption(
                this.lambdaRegularizationOption.getName(),
                this.lambdaRegularizationOption.getCLIChar(),
                this.lambdaRegularizationOption.getPurpose(),
                0.0d);
        this.learningRateOption = new FloatOption(
                this.learningRateOption.getName(),
                this.learningRateOption.getCLIChar(),
                this.learningRateOption.getPurpose(),
                0.01d);
    }

    /**
     * Resets the learner and re-reads lambda, learning rate, epsilon and the loss
     * function from their options.
     */
    @Override
    public void resetLearningImpl() {
        reset();
        // The inherited reset() is not visible here and may not know about this
        // accumulator, so clear it explicitly; otherwise the bias would keep the
        // shrunken step size of the previous run.
        this.m_biasVelocity = 0.0d;
        setLambda(this.lambdaRegularizationOption.getValue());
        setLearningRate(this.learningRateOption.getValue());
        setEpsilon(this.epsilonOption.getValue());
        setLossFunction(this.lossFunctionOption.getChosenIndex());
    }

    /**
     * Performs one AdaGrad update from a single instance. Instances with a
     * missing class are ignored (after the model has been lazily initialised).
     *
     * @param instance the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.m_weights == null) {
            // Lazy initialisation: every piece of per-model state starts from zero,
            // including the bias accumulator.
            this.m_weights = new DoubleVector();
            this.m_velocity = new DoubleVector();
            this.m_bias = 0.0d;
            this.m_biasVelocity = 0.0d;
            this.m_weights.setValue(instance.numAttributes(), 0.0d);
            this.m_velocity.setValue(instance.numAttributes(), 0.0d);
        }
        if (instance.classIsMissing()) {
            return;
        }

        int classIndex = instance.classIndex();
        double z = dotProd(instance, this.m_weights, classIndex) + this.m_bias;
        double dldz = lossGradient(instance, z);

        // Per-weight gradients: data term plus the regularization term, the latter
        // only for attributes that are stored in this instance.
        DoubleVector gradients = new DoubleVector();
        gradients.setValue(instance.numAttributes(), 0.0d);
        double regScale = this.m_lambda / (this.m_t + this.m_epsilon);
        int numValues = instance.numValues();
        for (int i = 0; i < numValues; i++) {
            int index = instance.index(i);
            double x = instance.valueSparse(i);
            // The class attribute is excluded from the dot product above, so it must
            // not be trained on either; a missing (NaN) value would otherwise poison
            // the weight and its accumulator permanently.
            if (index == classIndex || Double.isNaN(x)) {
                continue;
            }
            gradients.setValue(index, (x * dldz) + (regScale * this.m_weights.getValue(index)));
        }

        // A zero gradient contributes nothing; skipping it also avoids
        // (learningRate / 0) * 0 = NaN when epsilon is 0 and the accumulator is empty.
        if (dldz != 0.0d) {
            this.m_biasVelocity += dldz * dldz;
            this.m_bias -= adaptiveStep(this.m_biasVelocity, dldz);
        }
        int numWeights = this.m_weights.numValues();
        for (int i = 0; i < numWeights; i++) {
            double g = gradients.getValue(i);
            if (g == 0.0d) {
                continue;
            }
            this.m_velocity.addToValue(i, g * g);
            this.m_weights.addToValue(i, -adaptiveStep(this.m_velocity.getValue(i), g));
        }
        this.m_t += 1.0d;
    }

    /**
     * Derivative of the loss with respect to the linear output {@code z}.
     *
     * <p>Non-nominal class: {@code z - classValue}. Nominal class: the class is
     * binarised (value 0 versus anything else) and the gradient is
     * {@code sigmoid(z) - y} when the loss index is {@link #LOG_LOSS_INDEX},
     * otherwise the hinge sub-gradient ({@code -y'} when {@code y' * z < 1},
     * else 0, with {@code y'} in {-1, +1}).
     */
    private double lossGradient(Instance instance, double z) {
        if (!instance.classAttribute().isNominal()) {
            return z - instance.classValue();
        }
        double y = instance.classValue() == 0.0d ? 0.0d : 1.0d;
        if (this.m_loss == LOG_LOSS_INDEX) {
            return (1.0d / (1.0d + Math.exp(-z))) - y;
        }
        double signedLabel = (y * 2.0d) - 1.0d;
        return signedLabel * z < 1.0d ? -signedLabel : 0.0d;
    }

    /** Step for one parameter: {@code learningRate / (sqrt(accumulated) + epsilon) * gradient}. */
    private double adaptiveStep(double accumulatedSquares, double gradient) {
        return (this.m_learningRate / (Math.sqrt(accumulatedSquares) + this.m_epsilon)) * gradient;
    }

    /**
     * Returns the name of this model.
     *
     * @return the string {@code "AdaGrad"}
     */
    @Override
    protected String getModelName() {
        return "AdaGrad";
    }
}
