package moa.classifiers.multilabel;

import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.HashMap;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiLabelLearner;
import moa.core.Measurement;
import moa.core.StringUtils;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/multilabel/MajorityLabelset.class */
/**
 * Baseline multi-label learner that always predicts the labelset (label vector) with the
 * largest accumulated instance weight seen so far.
 *
 * <p>Each training instance's label vector is encoded as a {@link MultiLabelPrediction} whose
 * votes for output {@code i} are {@code {1 - y_i, y_i}}. The string form of that prediction is
 * used as the key under which instance weights are accumulated.
 */
public class MajorityLabelset extends AbstractMultiLabelLearner implements MultiLabelLearner {
    private static final long serialVersionUID = 1;

    /** Largest accumulated weight of any labelset so far; -1 before any training. */
    private double maxValue = -1.0d;

    /** Labelset currently holding {@link #maxValue}, or {@code null} before any training. */
    private MultiLabelPrediction majorityLabelset = null;

    /** Accumulated instance weight per labelset, keyed by the labelset's string form. */
    private HashMap<String, Double> vectorCounts = new HashMap<>();

    @Override // moa.classifiers.AbstractClassifier, moa.options.AbstractOptionHandler, moa.options.OptionHandler
    public String getPurposeString() {
        return "Majority labelset classifier: always predicts the labelvector most frequently seen so far.";
    }

    /**
     * Discards all learned state: the majority labelset, the running maximum and every
     * accumulated labelset weight. Resetting only the majority labelset would let stale counts
     * and a stale maximum leak into the next learning run.
     */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.majorityLabelset = null;
        this.maxValue = -1.0d;
        // Allocate a fresh map rather than calling clear(), so this also works if the map
        // has not been initialized yet.
        this.vectorCounts = new HashMap<>();
    }

    /**
     * Adds the instance's weight to the count of its labelset and updates the majority
     * labelset if that count now exceeds the previous maximum.
     *
     * @param multiLabelInstance training instance whose label values and weight are recorded
     */
    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        int numberOutputTargets = multiLabelInstance.numberOutputTargets();
        MultiLabelPrediction labelset = new MultiLabelPrediction(numberOutputTargets);
        for (int i = 0; i < numberOutputTargets; i++) {
            double labelValue = multiLabelInstance.classValue(i);
            labelset.setVotes(i, new double[]{1.0d - labelValue, labelValue});
        }
        // Compute the key once; it identifies the labelset in vectorCounts.
        String key = labelset.toString();
        Double previous = this.vectorCounts.get(key);
        double weight = multiLabelInstance.weight();
        if (previous != null) {
            weight += previous.doubleValue();
        }
        this.vectorCounts.put(key, Double.valueOf(weight));
        // Strict '>' means that on ties the labelset that reached the maximum first is kept.
        if (weight > this.maxValue) {
            this.maxValue = weight;
            this.majorityLabelset = labelset;
        }
    }

    /**
     * Returns the current majority labelset. The stored instance itself is returned, not a copy.
     * Before any training, returns a fresh empty {@link MultiLabelPrediction} sized to the
     * instance's number of output targets.
     *
     * @param multiLabelInstance instance to predict for; only its number of outputs is used
     * @return the majority labelset prediction
     */
    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        if (this.majorityLabelset == null) {
            return new MultiLabelPrediction(multiLabelInstance.numberOutputTargets());
        }
        return this.majorityLabelset;
    }

    /** This learner reports no model-specific measurements. */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return false;
    }

    /**
     * Appends the current majority labelset (or a placeholder before any training) to
     * {@code sb} at the given indentation.
     *
     * @param sb     builder receiving the description
     * @param indent indentation level passed to {@link StringUtils#appendIndented}
     */
    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int indent) {
        StringUtils.appendIndented(sb, indent, "");
        // Guard against NullPointerException when the model is described before training.
        sb.append(this.majorityLabelset == null
                ? "<no labelset observed yet>"
                : this.majorityLabelset.toString());
        StringUtils.appendNewline(sb);
    }
}
