package moa.classifiers.bayes;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.util.Arrays;
import moa.capabilities.Capability;
import moa.capabilities.ImmutableCapabilities;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.core.Utils;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/bayes/NaiveBayesMultinomial.class */
/**
 * Incremental multinomial Naive Bayes classifier that treats every non-class attribute value as a
 * word count.
 *
 * <p>Per class, the model keeps:
 * <ul>
 *   <li>a smoothed prior count ({@link #m_probOfClass});</li>
 *   <li>the smoothed total word mass ({@link #m_classTotals});</li>
 *   <li>a smoothed count per attribute ({@link #m_wordTotalForClass}).</li>
 * </ul>
 * Additive (Laplace) smoothing uses the factor in {@link #laplaceCorrectionOption}. Votes are
 * computed in log space and converted to probabilities with {@link Utils#logs2probs(double[])}.
 */
public class NaiveBayesMultinomial extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = -7204398796974263187L;

    /** Per class: laplace * vocabulary size, plus the weighted sizes of all training instances. */
    protected double[] m_classTotals;

    /** Header-only copy of the training stream's dataset, captured at initialisation; used for printing. */
    protected Instances m_headerInfo;

    /** Number of class labels, taken from the first training instance after a reset. */
    protected int m_numClasses;

    /** Per class: laplace plus the summed weights of training instances (unnormalised prior). */
    protected double[] m_probOfClass;

    /**
     * Per class, smoothed word counts indexed by attribute index. A stored 0 means "never seen for
     * this class" and is read as the Laplace factor.
     */
    protected DoubleVector[] m_wordTotalForClass;

    public FloatOption laplaceCorrectionOption = new FloatOption("laplaceCorrection", 'l', "Laplace correction factor.", 1.0d, 0.0d, 2.147483647E9d);

    /**
     * True while the model holds no learned statistics. The next training instance (re)initialises
     * them. Starts as true so a classifier used before {@link #resetLearningImpl()} does not
     * dereference unallocated arrays.
     */
    protected boolean reset = true;

    @Override
    public String getPurposeString() {
        return "Multinomial Naive Bayes classifier: performs classic bayesian prediction while making naive assumption that all inputs are independent.";
    }

    /** Discards the learned model. Statistics are re-allocated lazily by the next training instance. */
    @Override
    public void resetLearningImpl() {
        this.reset = true;
    }

    /**
     * Adds one weighted training instance to the per-class statistics.
     *
     * <p>Instances whose class label is missing are ignored, because they carry no information
     * about any class.
     *
     * @param instance the labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.reset) {
            initialiseModel(instance);
        }
        // (int) NaN == 0 in Java, so without this check an unlabelled instance would be
        // silently attributed to class 0.
        if (instance.classIsMissing()) {
            return;
        }
        int classIndex = instance.classIndex();
        int classValue = (int) instance.classValue();
        double weight = instance.weight();
        double laplace = this.laplaceCorrectionOption.getValue();

        this.m_probOfClass[classValue] += weight;
        this.m_classTotals[classValue] += weight * totalSize(instance);

        DoubleVector wordCounts = this.m_wordTotalForClass[classValue];
        for (int i = 0; i < instance.numValues(); i++) {
            int index = instance.index(i);
            if (index != classIndex && !instance.isMissing(i)) {
                // The Laplace term is folded in once per (class, word): the first time the word
                // is seen for this class. Until then, a stored 0 is read as `laplace`.
                double smoothing = wordCounts.getValue(index) == 0.0d ? laplace : 0.0d;
                wordCounts.addToValue(index, weight * instance.valueSparse(i) + smoothing);
            }
        }
    }

    /**
     * Allocates the per-class statistics and seeds them with the Laplace correction.
     *
     * @param instance the first training instance after a reset; supplies the header
     */
    private void initialiseModel(Instance instance) {
        this.m_numClasses = instance.numClasses();
        double laplace = this.laplaceCorrectionOption.getValue();
        // The class attribute is never counted as a word (training and prediction both skip it),
        // so it must not contribute a Laplace term to the per-class total either.
        int vocabularySize = instance.classIndex() >= 0
                ? instance.numAttributes() - 1
                : instance.numAttributes();

        this.m_probOfClass = new double[this.m_numClasses];
        Arrays.fill(this.m_probOfClass, laplace);
        this.m_classTotals = new double[this.m_numClasses];
        Arrays.fill(this.m_classTotals, laplace * vocabularySize);
        this.m_wordTotalForClass = new DoubleVector[this.m_numClasses];
        for (int c = 0; c < this.m_numClasses; c++) {
            this.m_wordTotalForClass[c] = new DoubleVector();
        }
        // Keep an empty copy of the header (no instances) so getModelDescription can print
        // attribute and class names.
        this.m_headerInfo = instance.dataset() == null ? null : new Instances(instance.dataset(), 0);
        this.reset = false;
    }

    /**
     * Computes class-membership probabilities for {@code instance}.
     *
     * @param instance the instance to classify; its class value, if any, is ignored
     * @return one probability per class, or an all-zero array of length 2 when nothing has been
     *     learned yet
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        if (this.reset) {
            return new double[2];
        }
        double laplace = this.laplaceCorrectionOption.getValue();
        int classIndex = instance.classIndex();
        double documentSize = totalSize(instance);

        double[] logScores = new double[this.m_numClasses];
        for (int c = 0; c < this.m_numClasses; c++) {
            double score = Math.log(this.m_probOfClass[c]);
            // For an empty document this term is zero. Skipping it avoids 0 * log(0) = NaN when
            // laplace == 0 and class c was never trained.
            if (documentSize > 0.0d) {
                score -= documentSize * Math.log(this.m_classTotals[c]);
            }
            logScores[c] = score;
        }

        for (int i = 0; i < instance.numValues(); i++) {
            int index = instance.index(i);
            if (index == classIndex || instance.isMissing(i)) {
                continue;
            }
            double count = instance.valueSparse(i);
            // A zero count contributes nothing. Skipping it also avoids 0 * log(0) = NaN when the
            // Laplace factor is 0.
            if (count == 0.0d) {
                continue;
            }
            for (int c = 0; c < this.m_numClasses; c++) {
                double wordCount = this.m_wordTotalForClass[c].getValue(index);
                logScores[c] += count * Math.log(wordCount == 0.0d ? laplace : wordCount);
            }
        }
        return Utils.logs2probs(logScores);
    }

    /**
     * Sums the non-negative, non-missing, non-class attribute values of {@code instance}, i.e. the
     * document length in words.
     *
     * @param instance the instance to measure
     * @return the total word count; negative values are ignored
     */
    public double totalSize(Instance instance) {
        int classIndex = instance.classIndex();
        double size = 0.0d;
        for (int i = 0; i < instance.numValues(); i++) {
            if (instance.index(i) == classIndex || instance.isMissing(i)) {
                continue;
            }
            double value = instance.valueSparse(i);
            if (value >= 0.0d) {
                size += value;
            }
        }
        return size;
    }

    /** This learner reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    /**
     * Prints the learned model.
     *
     * <p>The first table lists each class's smoothed prior (printed as an unnormalised count). The
     * second lists, for each attribute, the smoothed word count divided by the class total.
     *
     * @param sb the builder to append to
     * @param i indentation level for the header line
     */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
        StringUtils.appendIndented(sb, i, "xxx MNB1 xxx\n\n");
        if (this.reset || this.m_headerInfo == null) {
            // Nothing learned yet, or no header was available: there is no table to print.
            sb.append("The model has not been trained yet.\n");
            StringUtils.appendNewline(sb);
            return;
        }
        sb.append("The independent probability of a class\n");
        sb.append("--------------------------------------\n");
        for (int c = 0; c < this.m_numClasses; c++) {
            sb.append(this.m_headerInfo.classAttribute().value(c)).append("\t")
                    .append(Double.toString(this.m_probOfClass[c])).append("\n");
        }
        sb.append("\nThe probability of a word given the class\n");
        sb.append("-----------------------------------------\n\t");
        for (int c = 0; c < this.m_numClasses; c++) {
            sb.append(this.m_headerInfo.classAttribute().value(c)).append("\t");
        }
        sb.append("\n");

        double laplace = this.laplaceCorrectionOption.getValue();
        int classIndex = this.m_headerInfo.classIndex();
        for (int att = 0; att < this.m_headerInfo.numAttributes(); att++) {
            if (att == classIndex) {
                continue;
            }
            sb.append(this.m_headerInfo.attribute(att).name()).append("\t");
            for (int c = 0; c < this.m_numClasses; c++) {
                double wordCount = this.m_wordTotalForClass[c].getValue(att);
                if (wordCount == 0.0d) {
                    wordCount = laplace;
                }
                sb.append(wordCount / this.m_classTotals[c]).append("\t");
            }
            sb.append("\n");
        }
        StringUtils.appendNewline(sb);
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    @Override
    public ImmutableCapabilities defineImmutableCapabilities() {
        if (getClass() == NaiveBayesMultinomial.class) {
            return new ImmutableCapabilities(Capability.VIEW_STANDARD, Capability.VIEW_LITE);
        }
        return new ImmutableCapabilities(Capability.VIEW_STANDARD);
    }
}
