package moa.classifiers.rules.multilabel.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.Prediction;
import meka.classifiers.multilabel.Evaluation;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiTargetRegressor;
import moa.classifiers.rules.featureranking.FeatureRanking;
import moa.classifiers.rules.featureranking.NoFeatureRanking;
import moa.classifiers.rules.multilabel.AMRulesMultiLabelLearner;
import moa.classifiers.rules.multilabel.core.voting.ErrorWeightedVoteMultiLabel;
import moa.classifiers.rules.multilabel.core.voting.UniformWeightedVoteMultiLabel;
import moa.classifiers.rules.multilabel.errormeasurers.AbstractMultiTargetErrorMeasurer;
import moa.classifiers.rules.multilabel.errormeasurers.MultiLabelErrorMeasurer;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/rules/multilabel/meta/MultiLabelRandomAMRules.class */
/**
 * Ensemble of {@link AMRulesMultiLabelLearner} base learners for multi-label / multi-target problems.
 *
 * <p>Each member is a copy of the configured base learner. The copy is configured with
 * {@code setAttributesPercentage(numAttributesPercentage)} and gets its own random seed drawn from the
 * ensemble RNG.
 *
 * <p>When {@code useBagging} is set, each member sees every training instance with its weight multiplied
 * by a Poisson(1) draw, and skips the instance when the draw is 0.
 *
 * <p>Member predictions are combined with the configured {@link ErrorWeightedVoteMultiLabel} voting
 * function. Each member's vote is weighted either by the errors accumulated in its error measurer
 * ("Overall") or by the per-output errors reported with its votes ("Covered").
 */
public class MultiLabelRandomAMRules extends AbstractMultiLabelLearner implements MultiTargetRegressor {
    private static final long serialVersionUID = 1;

    /** Ensemble members; (re)built in {@link #resetLearningImpl()}. */
    protected AMRulesMultiLabelLearner[] ensemble;

    /** One error measurer per member, fed with the member's prediction before it trains on an instance. */
    protected MultiLabelErrorMeasurer[] errorMeasurer;

    /** Whether the base learner is a {@link MultiTargetRegressor}; set on reset, not read in this class. */
    protected boolean isRegression;

    /** Feature-ranking observer handed to every ensemble member. */
    protected FeatureRanking featureRanking;

    /** Number of input attributes, captured from the first training instance (0 until then). */
    private int nAttributes = 0;

    public IntOption VerbosityOption = new IntOption(Evaluation.FLAG_VERBOSITY, 'v', "Output Verbosity Control Level. 1 (Less) to 2 (More)", 1, 1, 2);
    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", AMRulesMultiLabelLearner.class, "AMRulesMultiTargetRegressor");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
    public FloatOption numAttributesPercentageOption = new FloatOption("numAttributesPercentage", 'n', "The number of attributes to use per model.", 63.2d, 0.0d, 100.0d);
    public FlagOption useBaggingOption = new FlagOption("useBagging", 'p', "Use Bagging.");
    public ClassOption votingFunctionOption = new ClassOption("votingFunction", 'V', "Voting Function.", ErrorWeightedVoteMultiLabel.class, UniformWeightedVoteMultiLabel.class.getName());
    public MultiChoiceOption votingTypeOption = new MultiChoiceOption("votingTypeOption", 'C', "Select whether the base learner error is computed as the overall error or only the error of the rules that cover the example.", new String[]{"Overall (Static)", "Only rules covered (Dynamic)"}, new String[]{"Overall", "Covered"}, 0);
    public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the classifier.", 1);
    public ClassOption errorMeasurerOption = new ClassOption("errorMeasurer", 'e', "Measure of error for deciding which learner should predict.", AbstractMultiTargetErrorMeasurer.class, "MeanAbsoluteDeviationMT");
    public ClassOption featureRankingOption = new ClassOption("featureRanking", 'F', "Feature ranking algorithm.", FeatureRanking.class, NoFeatureRanking.class.getName());

    /**
     * Rebuilds the ensemble.
     *
     * <p>The ensemble RNG is reseeded first. Then {@code ensembleSize} copies of the prepared base learner
     * are created, each with its own seed drawn from that RNG and its own copy of the prepared error
     * measurer. Every copy is attached to the feature-ranking observer.
     */
    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.classifierRandom.setSeed(this.randomSeedOption.getValue());
        int size = this.ensembleSizeOption.getValue();
        this.ensemble = new AMRulesMultiLabelLearner[size];
        this.errorMeasurer = new MultiLabelErrorMeasurer[size];
        AMRulesMultiLabelLearner baseLearner = (AMRulesMultiLabelLearner) getPreparedClassOption(this.baseLearnerOption);
        MultiLabelErrorMeasurer baseMeasurer = (MultiLabelErrorMeasurer) getPreparedClassOption(this.errorMeasurerOption);
        baseLearner.setAttributesPercentage(this.numAttributesPercentageOption.getValue());
        baseLearner.resetLearning();
        this.featureRanking = (FeatureRanking) getPreparedClassOption(this.featureRankingOption);
        for (int i = 0; i < size; i++) {
            this.ensemble[i] = (AMRulesMultiLabelLearner) baseLearner.copy();
            this.ensemble[i].setRandomSeed(this.classifierRandom.nextInt());
            // Attach the observer here. featureRanking is always non-null after this method, so a lazy
            // "featureRanking == null" check during training would never fire and the members would
            // never be wired to the ranking.
            this.ensemble[i].setObserver(this.featureRanking);
            this.errorMeasurer[i] = (MultiLabelErrorMeasurer) baseMeasurer.copy();
        }
        this.isRegression = baseLearner instanceof MultiTargetRegressor;
        // Re-captured from the first instance seen after this reset.
        this.nAttributes = 0;
    }

    /**
     * Trains every ensemble member on the given instance (test-then-train).
     *
     * <p>Each participating member first predicts a private copy of the instance, and that prediction is
     * added to the member's error measurer. The member then trains on the copy. With {@code useBagging}
     * set, the copy's weight is multiplied by a Poisson(1) draw, and a draw of 0 skips that member.
     *
     * @param multiLabelInstance the training instance; it is not modified
     */
    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        if (this.featureRanking == null) {
            // Defensive: only reachable if a subclass cleared the field after reset.
            this.featureRanking = (FeatureRanking) getPreparedClassOption(this.featureRankingOption);
            for (int i = 0; i < this.ensemble.length; i++) {
                this.ensemble[i].setObserver(this.featureRanking);
            }
        }
        if (this.nAttributes == 0) {
            // Needed by getModelMeasurementsImpl to report one ranking value per input attribute.
            this.nAttributes = multiLabelInstance.numInputAttributes();
        }
        boolean bagging = this.useBaggingOption.isSet();
        for (int i = 0; i < this.ensemble.length; i++) {
            int k = bagging ? MiscUtils.poisson(1.0d, this.classifierRandom) : 1;
            if (k > 0) {
                // Per-member copy so the weight change is not visible to other members or the caller.
                MultiLabelInstance weighted = (MultiLabelInstance) multiLabelInstance.copy();
                weighted.setWeight(weighted.weight() * k);
                Prediction prediction = this.ensemble[i].getPredictionForInstance(weighted);
                if (prediction != null) {
                    this.errorMeasurer[i].addPrediction(prediction, weighted);
                }
                this.ensemble[i].trainOnInstance(weighted);
            }
        }
    }

    /**
     * Combines the members' votes into one prediction using a fresh copy of the configured voting function.
     *
     * <p>With voting type "Overall" (index 0), each vote is weighted by the member's error-measurer errors.
     * Otherwise it is weighted by the per-output errors returned with the member's votes. When verbosity
     * is above 1, each member's prediction and weighted error, the combined prediction and the instance's
     * class value are printed to standard output.
     *
     * @param multiLabelInstance the instance to predict
     * @return the combined prediction as computed by the voting function
     */
    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        // Copy the prepared prototype so votes never accumulate across calls.
        ErrorWeightedVoteMultiLabel combined = (ErrorWeightedVoteMultiLabel) ((ErrorWeightedVoteMultiLabel) getPreparedClassOption(this.votingFunctionOption)).copy();
        boolean verbose = this.VerbosityOption.getValue() > 1;
        boolean useOverallError = this.votingTypeOption.getChosenIndex() == 0;
        StringBuilder trace = verbose ? new StringBuilder() : null;
        for (int i = 0; i < this.ensemble.length; i++) {
            ErrorWeightedVoteMultiLabel votes = this.ensemble[i].getVotes(multiLabelInstance);
            if (verbose) {
                trace.append(votes.getPrediction() + ",  E: " + votes.getWeightedError() + " ");
            }
            Prediction prediction = votes.getPrediction();
            if (prediction != null) {
                if (useOverallError) {
                    combined.addVote(prediction, this.errorMeasurer[i].getCurrentErrors());
                } else {
                    combined.addVote(prediction, votes.getOutputAttributesErrors());
                }
            }
        }
        Prediction result = combined.computeWeightedVote();
        if (verbose) {
            trace.append(result + ", ").append(multiLabelInstance.classValue());
            System.out.println(trace.toString());
        }
        return result;
    }

    /**
     * Reports the ensemble's model measurements.
     *
     * <p>Layout of the returned array:
     * <ul>
     *   <li>index 0: "ensemble size" (0 when the ensemble has not been built yet);</li>
     *   <li>indices 1..n: "Avg &lt;name&gt;" for each measurement of the first member, averaged over all
     *       members;</li>
     *   <li>when a feature ranking other than {@link NoFeatureRanking} is in use, one "Attribute&lt;i&gt;"
     *       entry per input attribute, taken from the ranking (0 if no rankings are available).</li>
     * </ul>
     */
    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        int ensembleSize = this.ensemble != null ? this.ensemble.length : 0;
        // Query each member once rather than once per measurement index.
        Measurement[][] memberMeasurements = new Measurement[ensembleSize][];
        for (int j = 0; j < ensembleSize; j++) {
            memberMeasurements[j] = this.ensemble[j].getModelMeasurements();
        }
        int nBase = ensembleSize > 0 ? memberMeasurements[0].length : 0;
        boolean rankFeatures = this.featureRanking != null && !(this.featureRanking instanceof NoFeatureRanking);
        Measurement[] result = new Measurement[rankFeatures ? nBase + this.nAttributes + 1 : nBase + 1];
        result[0] = new Measurement("ensemble size", ensembleSize);
        for (int m = 0; m < nBase; m++) {
            double sum = 0.0d;
            for (int j = 0; j < ensembleSize; j++) {
                sum += memberMeasurements[j][m].getValue();
            }
            result[m + 1] = new Measurement("Avg " + memberMeasurements[0][m].getName(), sum / ensembleSize);
        }
        if (rankFeatures) {
            DoubleVector rankings = this.featureRanking.getFeatureRankings();
            for (int a = 0; a < this.nAttributes; a++) {
                double score = rankings != null ? rankings.getValue(a) : 0.0d;
                result[a + nBase + 1] = new Measurement("Attribute" + a, score);
            }
        }
        return result;
    }

    /** Produces no textual description of the ensemble; {@code sb} is left unchanged. */
    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** Always {@code true}: member seeds and bagging weights are drawn from the seeded ensemble RNG. */
    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return true;
    }
}
