package moa.classifiers.functions;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.core.Utils;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/functions/SPegasos.class */
/**
 * Stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM)
 * method of Shalev-Shwartz et al. (2007).
 *
 * <p>The model is a single linear function over the attributes plus a bias term, so
 * it separates two outcomes only: a class value of {@code 0} is mapped to the label
 * {@code -1} and every other class value to {@code +1}. The weight vector has one slot
 * per attribute (the slot at the class index is never used) followed by the bias in the
 * last slot.
 *
 * <p>Either hinge loss ({@link #HINGE}) or log loss ({@link #LOGLOSS}) can be optimised.
 */
public class SPegasos extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = -3732968666673530290L;

    /** Loss-function code for hinge loss (SVM). */
    protected static final int HINGE = 0;

    /** Loss-function code for log loss (logistic regression). */
    protected static final int LOGLOSS = 1;

    /** Length of the vote array produced by {@link #getVotesForInstance(Instance)}. */
    private static final int NUM_VOTES = 2;

    /** Attribute weights followed by the bias in the last slot; {@code null} until training starts. */
    protected double[] m_weights;

    /** Iteration counter used to derive the learning rate and the shrink factor. */
    protected double m_t;

    /** Regularization constant lambda. */
    protected double m_lambda = 1.0E-4d;

    public FloatOption lambdaRegularizationOption = new FloatOption("lambdaRegularization", 'l',
            "Lambda regularization parameter .", 1.0E-4d, 0.0d, Integer.MAX_VALUE);

    /** Active loss-function code, {@link #HINGE} or {@link #LOGLOSS}. */
    protected int m_loss = HINGE;

    public MultiChoiceOption lossFunctionOption = new MultiChoiceOption("lossFunction", 'o',
            "The loss function to use.",
            new String[]{"HINGE", "LOGLOSS"},
            new String[]{"Hinge loss (SVM)", "Log loss (logistic regression)"},
            HINGE);

    /**
     * Returns a one-sentence description of this learner.
     *
     * @return the purpose string
     */
    @Override
    public String getPurposeString() {
        return "Stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007).";
    }

    /**
     * Sets the regularization constant.
     *
     * <p>NOTE: a value of {@code 0} makes the learning rate {@code 1 / (lambda * t)}
     * infinite, so training with it drives the weights to infinity or NaN.
     *
     * @param lambda the regularization constant
     */
    public void setLambda(double lambda) {
        m_lambda = lambda;
    }

    /**
     * Returns the regularization constant.
     *
     * @return the current lambda
     */
    public double getLambda() {
        return m_lambda;
    }

    /**
     * Sets the loss function.
     *
     * @param lossFunction {@link #HINGE} or {@link #LOGLOSS}
     */
    public void setLossFunction(int lossFunction) {
        m_loss = lossFunction;
    }

    /**
     * Returns the loss function code.
     *
     * @return {@link #HINGE} or {@link #LOGLOSS}
     */
    public int getLossFunction() {
        return m_loss;
    }

    /**
     * Discards the learned weights and restarts the iteration counter.
     *
     * <p>The counter restarts at 2, so the first shrink factor {@code 1 - 1/t} is 0.5.
     */
    public void reset() {
        m_t = 2.0d;
        m_weights = null;
    }

    /**
     * Computes the dot product between an instance and the attribute part of a weight
     * vector.
     *
     * <p>The last slot of {@code weights} (the bias) is not part of the product. Values
     * stored at the class index and missing values contribute nothing.
     *
     * @param inst       the instance whose stored values are multiplied
     * @param weights    attribute weights followed by one trailing bias slot
     * @param classIndex attribute index to leave out of the product
     * @return the sum of {@code value * weight} over the usable attributes
     */
    protected static double dotProd(Instance inst, double[] weights, int classIndex) {
        final int featureCount = weights.length - 1;
        final int storedValues = inst.numValues();
        double sum = 0.0d;
        for (int k = 0; k < storedValues; k++) {
            final int attIndex = inst.index(k);
            // Skip anything that has no attribute weight (bias slot and beyond),
            // the class attribute, and missing values.
            if (attIndex >= featureCount || attIndex == classIndex || inst.isMissingSparse(k)) {
                continue;
            }
            sum += inst.valueSparse(k) * weights[attIndex];
        }
        return sum;
    }

    /**
     * Returns the magnitude of the loss derivative at the given margin.
     *
     * <p>For hinge loss this is 1 while the margin is below 1 and 0 otherwise. For any
     * other loss code the logistic form {@code 1 / (1 + e^z)} is used, evaluated in
     * whichever of two equivalent forms keeps the exponent non-positive.
     *
     * @param z the margin {@code y * (w . x + b)}
     * @return the derivative magnitude
     */
    protected double dloss(double z) {
        if (m_loss == HINGE) {
            return z < 1.0d ? 1.0d : 0.0d;
        }
        if (z < 0.0d) {
            return 1.0d / (Math.exp(z) + 1.0d);
        }
        final double expNeg = Math.exp(-z);
        return expNeg / (expNeg + 1.0d);
    }

    /**
     * Resets the model and re-reads lambda and the loss function from the options.
     */
    @Override
    public void resetLearningImpl() {
        reset();
        setLambda(lambdaRegularizationOption.getValue());
        setLossFunction(lossFunctionOption.getChosenIndex());
    }

    /**
     * Performs one stochastic sub-gradient step on the given instance.
     *
     * <p>The weight vector is allocated on the first call. Instances with a missing
     * class leave the weights and the iteration counter untouched.
     *
     * @param inst the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance inst) {
        if (m_weights == null) {
            m_weights = new double[inst.numAttributes() + 1];
        }
        if (inst.classIsMissing()) {
            return;
        }

        final int classIndex = inst.classIndex();
        final int biasIndex = m_weights.length - 1;

        final double learningRate = 1.0d / (m_lambda * m_t);
        final double shrink = 1.0d - (1.0d / m_t);
        // Class value 0 is the negative label; everything else is positive.
        final double y = inst.classValue() == 0.0d ? -1.0d : 1.0d;
        // The margin is taken with the weights as they were before this step.
        final double margin = y * (dotProd(inst, m_weights, classIndex) + m_weights[biasIndex]);

        // Regularization: shrink every attribute weight (the bias is not shrunk).
        scaleFeatureWeights(shrink, classIndex);

        // Log loss always has a non-zero gradient; hinge only inside the margin.
        if (m_loss == LOGLOSS || margin < 1.0d) {
            final double step = learningRate * dloss(margin);
            final int storedValues = inst.numValues();
            for (int k = 0; k < storedValues; k++) {
                final int attIndex = inst.index(k);
                if (attIndex != classIndex && !inst.isMissingSparse(k)) {
                    m_weights[attIndex] += step * inst.valueSparse(k) * y;
                }
            }
            m_weights[biasIndex] += step * y;
        }

        // Projection: rescale the attribute weights when their squared norm exceeds 1 / lambda.
        double squaredNorm = 0.0d;
        for (int j = 0; j < biasIndex; j++) {
            if (j != classIndex) {
                squaredNorm += m_weights[j] * m_weights[j];
            }
        }
        final double projection = Math.min(1.0d, 1.0d / (m_lambda * squaredNorm));
        if (projection < 1.0d) {
            scaleFeatureWeights(Math.sqrt(projection), classIndex);
        }

        m_t += 1.0d;
    }

    /**
     * Multiplies every attribute weight except the one at the class index by a factor;
     * the trailing bias slot is left alone.
     */
    private void scaleFeatureWeights(double factor, int classIndex) {
        final int featureCount = m_weights.length - 1;
        for (int j = 0; j < featureCount; j++) {
            if (j != classIndex) {
                m_weights[j] *= factor;
            }
        }
    }

    /**
     * Returns the votes for the two outcomes of the linear model.
     *
     * <p>The result always has two slots: index 0 for class value 0 and index 1 for the
     * other outcome. Before any model exists both votes are zero. With hinge loss the
     * winning slot is set to 1; with log loss the two slots hold the logistic
     * probabilities.
     *
     * @param inst the instance to classify
     * @return a two-element vote array
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        final double[] votes = new double[NUM_VOTES];
        if (m_weights == null) {
            return votes;
        }

        final double score = dotProd(inst, m_weights, inst.classIndex())
                + m_weights[m_weights.length - 1];
        final int winner = score <= 0.0d ? 0 : 1;

        if (m_loss == LOGLOSS) {
            // Evaluate the sigmoid on the winner's side so the exponent is never positive.
            votes[winner] = 1.0d / (1.0d + Math.exp(winner == 0 ? score : -score));
            votes[1 - winner] = 1.0d - votes[winner];
        } else {
            votes[winner] = 1.0d;
        }
        return votes;
    }

    /**
     * Appends the textual model (see {@link #toString()}) to the given builder.
     *
     * @param out    the builder to append to
     * @param indent the indentation passed to {@code StringUtils.appendIndented}
     */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        StringUtils.appendIndented(out, indent, toString());
        StringUtils.appendNewline(out);
    }

    /**
     * Renders the loss function and the linear model, one weight per line, followed by
     * the signed bias.
     *
     * @return the textual model, or a placeholder when nothing has been learned yet
     */
    @Override
    public String toString() {
        if (m_weights == null) {
            return "SPegasos: No model built yet.\n";
        }

        final StringBuilder text = new StringBuilder();
        text.append("Loss function: ");
        text.append(m_loss == HINGE
                ? "Hinge loss (SVM)\n\n"
                : "Log loss (logistic regression)\n\n");

        final int biasIndex = m_weights.length - 1;
        for (int j = 0; j < biasIndex; j++) {
            text.append(j > 0 ? " + " : "   ");
            text.append(Utils.doubleToString(m_weights[j], 12, 4)).append(" \n");
        }

        final double bias = m_weights[biasIndex];
        if (bias > 0.0d) {
            text.append(" + ").append(Utils.doubleToString(bias, 12, 4));
        } else {
            text.append(" - ").append(Utils.doubleToString(-bias, 12, 4));
        }
        return text.toString();
    }

    /**
     * This learner reports no model-specific measurements.
     *
     * @return {@code null}
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Training uses no random numbers.
     *
     * @return {@code false}
     */
    @Override
    public boolean isRandomizable() {
        return false;
    }
}
