package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.gui.knowledgeflow.KnowledgeFlowApp;

/* loaded from: input_file:weka-stable-3.8.4.jar:weka/classifiers/functions/SGD.class */
public class SGD extends RandomizableClassifier implements UpdateableClassifier, OptionHandler, Aggregateable<SGD> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = -3732968666673530290L;

    /** Missing-value filter; created by buildClassifier when there is data and replacement is not disabled. */
    protected ReplaceMissingValues m_replaceMissing;

    /** Nominal-to-binary filter; created by buildClassifier only when a non-numeric input attribute exists. */
    protected Filter m_nominalToBinary;

    /** Normalization filter; created by buildClassifier when there is data and normalization is not disabled. */
    protected Normalize m_normalize;

    /** Model coefficients, indexed by attribute; the last element holds the bias term. */
    protected double[] m_weights;

    /** Update counter: set to 1 by reset() and incremented after each processed instance. */
    protected double m_t;

    /** Number of instances used by the most recent batch build. */
    protected double m_numInstances;

    /** Header-only copy of the (filtered) training data. */
    protected Instances m_data;

    /** Loss identifier: hinge loss. */
    public static final int HINGE = 0;

    /** Loss identifier: log loss. */
    public static final int LOGLOSS = 1;

    /** Loss identifier: squared loss. */
    public static final int SQUAREDLOSS = 2;

    /** Loss identifier: epsilon-insensitive loss. */
    public static final int EPSILON_INSENSITIVE = 3;

    /** Loss identifier: Huber loss. */
    public static final int HUBER = 4;

    /** Selectable loss functions, keyed by the identifiers above. */
    public static final Tag[] TAGS_SELECTION = {
        new Tag(HINGE, "Hinge loss (SVM)"),
        new Tag(LOGLOSS, "Log loss (logistic regression)"),
        new Tag(SQUAREDLOSS, "Squared loss (regression)"),
        new Tag(EPSILON_INSENSITIVE, "Epsilon-insensitive loss (SVM regression)"),
        new Tag(HUBER, "Huber loss (robust regression)")
    };

    /** Regularization constant. */
    protected double m_lambda = 0.0001;

    /** Learning rate. */
    protected double m_learningRate = 0.01;

    /** Threshold used by the epsilon-insensitive and Huber losses. */
    protected double m_epsilon = 0.001;

    /** Number of passes over the data in batch training. */
    protected int m_epochs = 500;

    /** When true, the normalization filter is not applied. */
    protected boolean m_dontNormalize = false;

    /** When true, missing values are not globally replaced. */
    protected boolean m_dontReplaceMissing = false;

    /** Identifier of the loss function in use. */
    protected int m_loss = HINGE;

    /** Number of other models added through aggregate() since the last finalizeAggregation(). */
    protected int m_numModels = 0;

    /**
     * Describes the data this classifier can handle: nominal and numeric
     * attributes with missing values, a numeric class for the squared,
     * epsilon-insensitive and Huber losses, and a binary class otherwise.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // attributes
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);

        // class: the three regression losses need a numeric target
        boolean regressionLoss =
            m_loss == SQUAREDLOSS || m_loss == EPSILON_INSENSITIVE || m_loss == HUBER;
        Capabilities.Capability classType = regressionLoss
            ? Capabilities.Capability.NUMERIC_CLASS
            : Capabilities.Capability.BINARY_CLASS;
        result.enable(classType);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);

        // a header-only dataset is acceptable
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Returns the tooltip text for the epsilon property.
     *
     * @return the description shown for the epsilon threshold
     */
    public String epsilonTipText() {
        return "The epsilon threshold for epsilon insensitive and Huber loss. An error with absolute value less than this threshold has loss of 0 for epsilon insensitive loss. For Huber loss this is the boundary between the quadratic and linear parts of the loss function.";
    }

    /**
     * Sets the threshold used by the epsilon-insensitive and Huber losses.
     *
     * @param d the new epsilon value
     */
    public void setEpsilon(double d) {
        m_epsilon = d;
    }

    /**
     * Gets the threshold used by the epsilon-insensitive and Huber losses.
     *
     * @return the current epsilon value
     */
    public double getEpsilon() {
        return m_epsilon;
    }

    /**
     * Returns the tooltip text for the lambda property.
     *
     * @return the description shown for the regularization constant
     */
    public String lambdaTipText() {
        final String tip = "The regularization constant. (default = 0.0001)";
        return tip;
    }

    /**
     * Sets the regularization constant.
     *
     * @param d the new lambda value
     */
    public void setLambda(double d) {
        m_lambda = d;
    }

    /**
     * Gets the regularization constant.
     *
     * @return the current lambda value
     */
    public double getLambda() {
        return m_lambda;
    }

    /**
     * Sets the learning rate.
     *
     * @param d the new learning rate
     */
    public void setLearningRate(double d) {
        m_learningRate = d;
    }

    /**
     * Gets the learning rate.
     *
     * @return the current learning rate
     */
    public double getLearningRate() {
        return m_learningRate;
    }

    /**
     * Returns the tooltip text for the learning rate property.
     *
     * @return the description shown for the learning rate
     */
    public String learningRateTipText() {
        // The original text ran two words together ("thenthe").
        return "The learning rate. If normalization is turned off (as it is automatically for streaming data), then "
            + "the default learning rate will need to be reduced (try 0.0001).";
    }

    /**
     * Returns the tooltip text for the epochs property.
     *
     * @return the description shown for the number of epochs
     */
    public String epochsTipText() {
        final String tip = "The number of epochs to perform (batch learning). The total number of iterations is epochs * num instances.";
        return tip;
    }

    /**
     * Sets the number of epochs used in batch training.
     *
     * @param i the new number of epochs
     */
    public void setEpochs(int i) {
        m_epochs = i;
    }

    /**
     * Gets the number of epochs used in batch training.
     *
     * @return the current number of epochs
     */
    public int getEpochs() {
        return m_epochs;
    }

    /**
     * Turns normalization of the input data off or on.
     *
     * @param z true to turn normalization off
     */
    public void setDontNormalize(boolean z) {
        m_dontNormalize = z;
    }

    /**
     * Reports whether normalization has been turned off.
     *
     * @return true if normalization is turned off
     */
    public boolean getDontNormalize() {
        return m_dontNormalize;
    }

    /**
     * Returns the tooltip text for the dontNormalize property.
     *
     * @return the description shown for the normalization switch
     */
    public String dontNormalizeTipText() {
        final String tip = "Turn normalization off";
        return tip;
    }

    /**
     * Turns global replacement of missing values off or on.
     *
     * @param z true to turn replacement off
     */
    public void setDontReplaceMissing(boolean z) {
        m_dontReplaceMissing = z;
    }

    /**
     * Reports whether global replacement of missing values has been turned off.
     *
     * @return true if replacement is turned off
     */
    public boolean getDontReplaceMissing() {
        return m_dontReplaceMissing;
    }

    /**
     * Returns the tooltip text for the dontReplaceMissing property.
     *
     * @return the description shown for the missing-value switch
     */
    public String dontReplaceMissingTipText() {
        final String tip = "Turn off global replacement of missing values";
        return tip;
    }

    /**
     * Sets the loss function to minimize. Tags that do not belong to
     * {@link #TAGS_SELECTION} are silently ignored.
     *
     * @param selectedTag the loss function to use
     */
    public void setLossFunction(SelectedTag selectedTag) {
        // Identity comparison: only tags built from our own tag table are accepted.
        if (selectedTag.getTags() != TAGS_SELECTION) {
            return;
        }
        m_loss = selectedTag.getSelectedTag().getID();
    }

    /**
     * Gets the loss function currently in use.
     *
     * @return the current loss as a tag from {@link #TAGS_SELECTION}
     */
    public SelectedTag getLossFunction() {
        SelectedTag current = new SelectedTag(m_loss, TAGS_SELECTION);
        return current;
    }

    /**
     * Returns the tooltip text for the lossFunction property. The text names
     * every loss offered in {@link #TAGS_SELECTION}.
     *
     * @return the description shown for the loss function
     */
    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM), log loss (logistic regression), "
            + "squared loss (regression), epsilon-insensitive loss (SVM regression) "
            + "or Huber loss (robust regression).";
    }

    /**
     * Lists the command-line options understood by this classifier, followed
     * by the options of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();

        options.add(new Option("\tSet the loss function to minimize.\n\t0 = hinge loss (SVM), 1 = log loss (logistic regression),\n\t2 = squared loss (regression), 3 = epsilon insensitive loss (regression),\n\t4 = Huber loss (regression).\n\t(default = 0)", "F", 1, "-F <integer>"));
        options.add(new Option("\tThe learning rate. If normalization is\n\tturned off (as it is automatically for streaming data), then the\n\tdefault learning rate will need to be reduced (try 0.0001).\n\t(default = 0.01).", "L", 1, "-L <double>"));
        options.add(new Option("\tThe lambda regularization constant (default = 0.0001)", "R", 1, "-R <double>"));
        options.add(new Option("\tThe number of epochs to perform (batch learning only, default = 500)", "E", 1, "-E <integer>"));
        options.add(new Option("\tThe epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)", "C", 1, "-C <double>"));
        options.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
        options.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

        // Inherited options come last.
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the given options. The model is reset first, superclass options
     * are consumed next, and then -F, -R, -L, -E, -C, -N and -M are read in
     * that order. Value options that are absent leave the current setting
     * unchanged; the two flags are always assigned.
     *
     * @param strArr the option list
     * @throws Exception if an option value cannot be parsed or unknown options remain
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        reset();
        super.setOptions(strArr);

        final String lossArg = Utils.getOption('F', strArr);
        if (!lossArg.isEmpty()) {
            final int lossId = Integer.parseInt(lossArg);
            setLossFunction(new SelectedTag(lossId, TAGS_SELECTION));
        }

        final String lambdaArg = Utils.getOption('R', strArr);
        if (!lambdaArg.isEmpty()) {
            setLambda(Double.parseDouble(lambdaArg));
        }

        final String rateArg = Utils.getOption('L', strArr);
        if (!rateArg.isEmpty()) {
            setLearningRate(Double.parseDouble(rateArg));
        }

        final String epochsArg = Utils.getOption("E", strArr);
        if (!epochsArg.isEmpty()) {
            setEpochs(Integer.parseInt(epochsArg));
        }

        final String epsilonArg = Utils.getOption("C", strArr);
        if (!epsilonArg.isEmpty()) {
            setEpsilon(Double.parseDouble(epsilonArg));
        }

        final boolean skipNormalization = Utils.getFlag("N", strArr);
        setDontNormalize(skipNormalization);

        final boolean skipMissingReplacement = Utils.getFlag('M', strArr);
        setDontReplaceMissing(skipMissingReplacement);

        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Returns the current settings as a command-line option array: -F, -L,
     * -R, -E and -C with their values, -N and -M when set, followed by the
     * superclass options.
     *
     * @return the current option settings
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();

        // String.valueOf replaces the decompiler's "<GUI constant> + value"
        // concatenation, which only served to stringify the value.
        options.add("-F");
        options.add(String.valueOf(getLossFunction().getSelectedTag().getID()));
        options.add("-L");
        options.add(String.valueOf(getLearningRate()));
        options.add("-R");
        options.add(String.valueOf(getLambda()));
        options.add("-E");
        options.add(String.valueOf(getEpochs()));
        options.add("-C");
        options.add(String.valueOf(getEpsilon()));

        if (getDontNormalize()) {
            options.add("-N");
        }
        if (getDontReplaceMissing()) {
            options.add("-M");
        }

        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns a description of this classifier.
     *
     * @return the descriptive text for this scheme
     */
    public String globalInfo() {
        return "Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression, squared loss, Huber loss and epsilon-insensitive loss linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.\nFor numeric class attributes, the squared, Huber or epsilon-insensitive loss function must be used. Epsilon-insensitive and Huber loss may require a much higher learning rate.";
    }

    /**
     * Discards the learned weights and restarts the update counter at 1.
     * Filters and option settings are left untouched.
     */
    public void reset() {
        m_weights = null;
        m_t = 1.0;
    }

    /**
     * Builds the model in batch mode. The data is copied, instances with a
     * missing class are dropped, and the preprocessing filters are set up:
     * missing-value replacement and normalization only when there is data and
     * they are not disabled, nominal-to-binary conversion only when a
     * non-numeric input attribute exists. A header-only dataset allocates the
     * weight vector without training.
     *
     * @param instances the training data (may contain zero instances)
     * @throws Exception if the data fails the capability check or a filter fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        reset();

        // Drop filters left over from an earlier build. Without this, a
        // rebuild with normalization or missing-value replacement disabled
        // (or on all-numeric data) would keep applying the stale filters in
        // updateClassifier and distributionForInstance.
        m_replaceMissing = null;
        m_nominalToBinary = null;
        m_normalize = null;

        getCapabilities().testWithFail(instances);

        // Copy before deleting rows.
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        if (data.numInstances() > 0 && !m_dontReplaceMissing) {
            m_replaceMissing = new ReplaceMissingValues();
            m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter(data, m_replaceMissing);
        }

        if (hasNonNumericInput(data)) {
            // Supervised filter when there is data; the unsupervised variant
            // for a header-only dataset.
            if (data.numInstances() > 0) {
                m_nominalToBinary = new NominalToBinary();
            } else {
                m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
            }
            m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter(data, m_nominalToBinary);
        }

        if (!m_dontNormalize && data.numInstances() > 0) {
            m_normalize = new Normalize();
            m_normalize.setInputFormat(data);
            data = Filter.useFilter(data, m_normalize);
        }

        m_numInstances = data.numInstances();
        // One weight slot per attribute plus a trailing bias term.
        m_weights = new double[data.numAttributes() + 1];
        m_data = new Instances(data, 0);

        if (data.numInstances() > 0) {
            data.randomize(new Random(getSeed()));
            train(data);
        }
    }

    /** Returns true if any attribute other than the class is not numeric. */
    private static boolean hasNonNumericInput(Instances data) {
        for (int att = 0; att < data.numAttributes(); att++) {
            if (att != data.classIndex() && !data.attribute(att).isNumeric()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Evaluates the loss-specific factor used in the weight update.
     * NOTE(review): KStarConstants.FLOOR looks like a decompiler stand-in for
     * the literal 0.0 - confirm before replacing it.
     *
     * @param d the margin (classification) or residual (regression), as
     *          computed by the caller
     * @return the factor for the loss selected in m_loss; any loss id other
     *         than hinge, log, epsilon-insensitive or Huber returns d unchanged
     */
    protected double dloss(double d) {
        switch (m_loss) {
            case HINGE:
                return d < 1.0 ? 1.0 : KStarConstants.FLOOR;

            case LOGLOSS: {
                // Two algebraically equal forms; each branch only
                // exponentiates a non-positive argument.
                if (d < KStarConstants.FLOOR) {
                    return 1.0 / (Math.exp(d) + 1.0);
                }
                final double expNeg = Math.exp(-d);
                return expNeg / (expNeg + 1.0);
            }

            case EPSILON_INSENSITIVE:
                if (d > m_epsilon) {
                    return 1.0;
                }
                if (-d > m_epsilon) {
                    return -1.0;
                }
                return KStarConstants.FLOOR;

            case HUBER:
                // Clipped to +/- epsilon outside the band, identity inside it.
                if (Math.abs(d) > m_epsilon) {
                    return d > KStarConstants.FLOOR ? m_epsilon : -m_epsilon;
                }
                return d;

            default:
                // Squared loss (and any unrecognized id).
                return d;
        }
    }

    /**
     * Runs m_epochs passes over the data, updating the model once per
     * instance and without re-applying the preprocessing filters.
     *
     * @param instances the already-filtered training data
     * @throws Exception if an update fails
     */
    private void train(Instances instances) throws Exception {
        for (int epoch = 0; epoch < m_epochs; epoch++) {
            for (int row = 0; row < instances.numInstances(); row++) {
                Instance current = instances.instance(row);
                updateClassifier(current, false);
            }
        }
    }

    /**
     * Computes the dot product between an instance's stored values and a
     * weight vector, skipping the class attribute and missing values. The
     * last weight (the bias slot) is never included.
     *
     * @param instance the instance, accessed through its sparse interface
     * @param dArr the weight vector; its last element is ignored
     * @param i the class index to skip
     * @return the dot product
     */
    protected static double dotProd(Instance instance, double[] dArr, int i) {
        double sum = 0.0;
        final int storedValues = instance.numValues();
        final int usableWeights = dArr.length - 1;

        int pos = 0;       // position within the instance's stored values
        int weightIdx = 0; // position within the weight vector

        // Merge-style walk over both index sequences.
        while (pos < storedValues && weightIdx < usableWeights) {
            final int attIndex = instance.index(pos);

            if (attIndex > weightIdx) {
                weightIdx++;
                continue;
            }
            if (attIndex < weightIdx) {
                pos++;
                continue;
            }

            // Indices line up: accumulate unless class or missing.
            if (attIndex != i && !instance.isMissingSparse(pos)) {
                sum += instance.valueSparse(pos) * dArr[weightIdx];
            }
            pos++;
            weightIdx++;
        }
        return sum;
    }

    /**
     * Performs one stochastic gradient step for a single instance. Instances
     * with a missing class are ignored and do not advance the update counter.
     * NOTE(review): KStarConstants.FLOOR looks like a decompiler stand-in for
     * the literal 0.0 - confirm before replacing it.
     *
     * @param instance the training instance
     * @param z true to pass the instance through the preprocessing filters
     *          first (incremental use); false if it is already filtered
     * @throws Exception if a filter fails
     */
    protected void updateClassifier(Instance instance, boolean z) throws Exception {
        if (instance.classIsMissing()) {
            return;
        }

        if (z) {
            // Same filter order as in batch training.
            if (m_replaceMissing != null) {
                m_replaceMissing.input(instance);
                instance = m_replaceMissing.output();
            }
            if (m_nominalToBinary != null) {
                m_nominalToBinary.input(instance);
                instance = m_nominalToBinary.output();
            }
            if (m_normalize != null) {
                m_normalize.input(instance);
                instance = m_normalize.output();
            }
        }

        final double wx = dotProd(instance, m_weights, instance.classIndex());
        final int biasIndex = m_weights.length - 1;

        final double sign;   // class label mapped to -1/+1, or 1 for regression
        final double margin; // y * f(x) for classification, residual for regression
        if (instance.classAttribute().isNominal()) {
            sign = instance.classValue() == KStarConstants.FLOOR ? -1.0 : 1.0;
            margin = sign * (wx + m_weights[biasIndex]);
        } else {
            margin = instance.classValue() - (wx + m_weights[biasIndex]);
            sign = 1.0;
        }

        // Regularization: shrink every non-bias weight. The divisor is the
        // batch size, or the update counter when no batch size is known.
        final double divisor = m_numInstances == KStarConstants.FLOOR ? m_t : m_numInstances;
        final double shrink = 1.0 - ((m_learningRate * m_lambda) / divisor);
        for (int k = 0; k < biasIndex; k++) {
            m_weights[k] *= shrink;
        }

        // Hinge and epsilon-insensitive losses only step when the example
        // violates the margin / epsilon band; the other losses always step.
        final boolean takeStep;
        switch (m_loss) {
            case SQUAREDLOSS:
            case LOGLOSS:
            case HUBER:
                takeStep = true;
                break;
            case HINGE:
                takeStep = margin < 1.0;
                break;
            case EPSILON_INSENSITIVE:
                takeStep = Math.abs(margin) > m_epsilon;
                break;
            default:
                takeStep = false;
                break;
        }

        if (takeStep) {
            final double step = m_learningRate * sign * dloss(margin);

            final int storedValues = instance.numValues();
            for (int pos = 0; pos < storedValues; pos++) {
                final int attIndex = instance.index(pos);
                if (attIndex == instance.classIndex() || instance.isMissingSparse(pos)) {
                    continue;
                }
                m_weights[attIndex] += step * instance.valueSparse(pos);
            }

            // bias update
            m_weights[biasIndex] += step;
        }

        m_t += 1.0;
    }

    /**
     * Updates the model incrementally with one raw instance; the instance is
     * passed through the preprocessing filters before the gradient step.
     *
     * @param instance the new training instance
     * @throws Exception if the update fails
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        final boolean applyFilters = true;
        updateClassifier(instance, applyFilters);
    }

    /**
     * Computes the prediction for an instance. For a numeric class the single
     * array element is the linear model's output. For a nominal class the
     * array has two elements: class probabilities under log loss, otherwise a
     * hard 0/1 assignment based on the sign of the output.
     * NOTE(review): KStarConstants.FLOOR looks like a decompiler stand-in for
     * the literal 0.0 - confirm before replacing it.
     *
     * @param instance the instance to classify (unfiltered)
     * @return the class distribution, or a one-element array for regression
     * @throws Exception if a filter fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        final double[] dist;
        if (instance.classAttribute().isNominal()) {
            dist = new double[2];
        } else {
            dist = new double[1];
        }

        // Apply the same preprocessing chain that was used in training.
        if (m_replaceMissing != null) {
            m_replaceMissing.input(instance);
            instance = m_replaceMissing.output();
        }
        if (m_nominalToBinary != null) {
            m_nominalToBinary.input(instance);
            instance = m_nominalToBinary.output();
        }
        if (m_normalize != null) {
            m_normalize.input(instance);
            instance = m_normalize.output();
        }

        final double score =
            dotProd(instance, m_weights, instance.classIndex()) + m_weights[m_weights.length - 1];

        if (instance.classAttribute().isNumeric()) {
            dist[0] = score;
            return dist;
        }

        final boolean logistic = m_loss == LOGLOSS;
        if (score <= KStarConstants.FLOOR) {
            // First class label.
            if (logistic) {
                dist[0] = 1.0 / (1.0 + Math.exp(score));
                dist[1] = 1.0 - dist[0];
            } else {
                dist[0] = 1.0;
            }
        } else {
            // Second class label.
            if (logistic) {
                dist[1] = 1.0 / (1.0 + Math.exp(-score));
                dist[0] = 1.0 - dist[1];
            } else {
                dist[1] = 1.0;
            }
        }
        return dist;
    }

    /**
     * Gets the model's weight vector. The internal array itself is returned,
     * not a copy; it is null until a model has been built.
     *
     * @return the weights, with the bias term as the last element
     */
    public double[] getWeights() {
        return m_weights;
    }

    /**
     * Renders the model as text: the loss function, then the linear equation
     * with one line per non-class attribute, and finally the bias term.
     *
     * @return the textual description, or a placeholder if no model is built
     */
    @Override
    public String toString() {
        if (m_weights == null) {
            return "SGD: No model built yet.\n";
        }

        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        switch (m_loss) {
            case HINGE:
                buff.append("Hinge loss (SVM)\n\n");
                break;
            case LOGLOSS:
                buff.append("Log loss (logistic regression)\n\n");
                break;
            case EPSILON_INSENSITIVE:
                buff.append("Epsilon insensitive loss (SVM regression)\n\n");
                break;
            case HUBER:
                buff.append("Huber loss (robust regression)\n\n");
                break;
            default:
                buff.append("Squared loss (linear regression)\n\n");
                break;
        }

        buff.append(m_data.classAttribute().name()).append(" = \n\n");

        // Plain literals replace the decompiler's references to constants in
        // unrelated GUI/test/KStar classes (empty string, single space, 0.0).
        final String normalizedTag = (m_normalize != null) ? "(normalized) " : "";
        final int biasIndex = m_weights.length - 1;

        int printed = 0;
        for (int i = 0; i < biasIndex; i++) {
            if (i == m_data.classIndex()) {
                continue;
            }
            buff.append(printed > 0 ? " + " : "   ");
            buff.append(Utils.doubleToString(m_weights[i], 12, 4))
                .append(" ")
                .append(normalizedTag)
                .append(m_data.attribute(i).name())
                .append("\n");
            printed++;
        }

        // Bias term, printed with an explicit sign.
        final double bias = m_weights[biasIndex];
        if (bias > 0.0) {
            buff.append(" + ").append(Utils.doubleToString(bias, 12, 4));
        } else {
            buff.append(" - ").append(Utils.doubleToString(-bias, 12, 4));
        }
        return buff.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the version-control keyword
     */
    @Override
    public String getRevision() {
        final String revisionKeyword = "$Revision: 14788 $";
        return RevisionUtils.extract(revisionKeyword);
    }

    /**
     * Adds another model's weights to this one. Call
     * {@link #finalizeAggregation()} afterwards to turn the sums into averages.
     *
     * @param sgd the model to aggregate; must have been built on data with
     *            the same header as this model
     * @return this model, for chaining
     * @throws Exception if either model is unbuilt, the headers differ, or
     *                   the weight vectors differ in length
     */
    @Override
    public SGD aggregate(SGD sgd) throws Exception {
        if (m_weights == null) {
            throw new Exception("No model built yet, can't aggregate");
        }

        // Checked before the header comparison: an unbuilt model would
        // otherwise fail with a NullPointerException instead of a clear message.
        final double[] otherWeights = sgd.getWeights();
        if (otherWeights == null || sgd.m_data == null) {
            throw new Exception("Can't aggregate - the SGD to aggregate has no model built yet.");
        }

        if (!m_data.equalHeaders(sgd.m_data)) {
            throw new Exception("Can't aggregate - data headers don't match: " + m_data.equalHeadersMsg(sgd.m_data));
        }
        if (m_weights.length != otherWeights.length) {
            throw new Exception("Can't aggregate - SGD to aggregate has weight vector that differs in length from ours.");
        }

        for (int i = 0; i < m_weights.length; i++) {
            m_weights[i] += otherWeights[i];
        }
        m_numModels++;
        return this;
    }

    /**
     * Completes an aggregation by dividing every summed weight by the number
     * of contributing models (the aggregated ones plus this one), then resets
     * the aggregation counter.
     *
     * @throws Exception if no models have been aggregated
     */
    @Override
    public void finalizeAggregation() throws Exception {
        if (m_numModels == 0) {
            throw new Exception("Unable to finalize aggregation - haven't seen any models to aggregate");
        }

        // +1 accounts for this model's own weights.
        final int contributors = m_numModels + 1;
        for (int k = 0; k < m_weights.length; k++) {
            m_weights[k] /= contributors;
        }
        m_numModels = 0;
    }

    /**
     * Command-line entry point: runs a fresh SGD classifier with the given
     * arguments.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        SGD classifier = new SGD();
        runClassifier(classifier, strArr);
    }
}
