package moa.classifiers.rules.multilabel;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.Arrays;
import java.util.Iterator;
import java.util.ListIterator;
import meka.classifiers.multilabel.Evaluation;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiLabelLearner;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.classifiers.rules.core.anomalydetection.AnomalyDetector;
import moa.classifiers.rules.core.anomalydetection.OddsRatioScore;
import moa.classifiers.rules.featureranking.FeatureRanking;
import moa.classifiers.rules.featureranking.NoFeatureRanking;
import moa.classifiers.rules.featureranking.messages.ChangeDetectedMessage;
import moa.classifiers.rules.multilabel.attributeclassobservers.NominalStatisticsObserver;
import moa.classifiers.rules.multilabel.attributeclassobservers.NumericStatisticsObserver;
import moa.classifiers.rules.multilabel.core.MultiLabelRule;
import moa.classifiers.rules.multilabel.core.MultiLabelRuleSet;
import moa.classifiers.rules.multilabel.core.ObserverMOAObject;
import moa.classifiers.rules.multilabel.core.splitcriteria.MultiLabelSplitCriterion;
import moa.classifiers.rules.multilabel.core.voting.ErrorWeightedVoteMultiLabel;
import moa.classifiers.rules.multilabel.errormeasurers.MultiLabelErrorMeasurer;
import moa.classifiers.rules.multilabel.functions.MultiLabelPerceptronClassification;
import moa.classifiers.rules.multilabel.inputselectors.InputAttributesSelector;
import moa.classifiers.rules.multilabel.inputselectors.SelectAllInputs;
import moa.classifiers.rules.multilabel.instancetransformers.NoInstanceTransformation;
import moa.classifiers.rules.multilabel.outputselectors.OutputAttributesSelector;
import moa.classifiers.rules.multilabel.outputselectors.SelectAllOutputs;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.options.ClassOption;
import org.apache.log4j.Priority;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/rules/multilabel/AMRulesMultiLabelLearner.class */
public abstract class AMRulesMultiLabelLearner extends AbstractMultiLabelLearner implements MultiLabelLearner {
    private static final long serialVersionUID = 1;
    protected MultiLabelRuleSet ruleSet;
    protected MultiLabelRule defaultRule;
    protected int ruleNumberID;
    protected double[] statistics;
    protected ObserverMOAObject observer;
    public FloatOption splitConfidenceOption;
    public FloatOption tieThresholdOption;
    public IntOption gracePeriodOption;
    public ClassOption learnerOption;
    public FlagOption unorderedRulesOption;
    public FlagOption dropOldRuleAfterExpansionOption;
    public ClassOption changeDetector;
    public ClassOption anomalyDetector;
    public ClassOption splitCriterionOption;
    public ClassOption errorMeasurerOption;
    public ClassOption weightedVoteOption;
    public ClassOption numericObserverOption;
    public ClassOption nominalObserverOption;
    public IntOption VerbosityOption;
    public ClassOption outputSelectorOption;
    public ClassOption inputSelectorOption;
    public IntOption randomSeedOption;
    public ClassOption featureRankingOption;
    private int nAttributes;
    protected double attributesPercentage;
    private double numChangesDetected;
    private double numAnomaliesDetected;
    private double numInstances;
    private FeatureRanking featureRanking;

    public double getAttributesPercentage() {
        return this.attributesPercentage;
    }

    public void setAttributesPercentage(double d) {
        this.attributesPercentage = d;
    }

    public AMRulesMultiLabelLearner() {
        this.ruleNumberID = 1;
        this.splitConfidenceOption = new FloatOption("splitConfidence", 'c', "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.", 1.0E-7d, 0.0d, 1.0d);
        this.tieThresholdOption = new FloatOption("tieThreshold", 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.", 0.05d, 0.0d, 1.0d);
        this.gracePeriodOption = new IntOption("gracePeriod", 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.", 200, 1, Integer.MAX_VALUE);
        this.unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U', "unorderedRules.");
        this.dropOldRuleAfterExpansionOption = new FlagOption("dropOldRuleAfterExpansion", 'D', "Drop old rule if it expanded (by default the rule is kept for the set of outputs not selected for expansion.)");
        this.changeDetector = new ClassOption("changeDetector", 'H', "Change Detector.", ChangeDetector.class, "PageHinkleyDM -d 0.05 -l 35.0");
        this.anomalyDetector = new ClassOption("anomalyDetector", 'A', "Anomaly Detector.", AnomalyDetector.class, OddsRatioScore.class.getName());
        this.numericObserverOption = new ClassOption("numericObserver", 'y', "Numeric observer.", NumericStatisticsObserver.class, "MultiLabelBSTree");
        this.nominalObserverOption = new ClassOption("nominalObserver", 'z', "Nominal observer.", NominalStatisticsObserver.class, "MultiLabelNominalAttributeObserver");
        this.VerbosityOption = new IntOption(Evaluation.FLAG_VERBOSITY, 'v', "Output Verbosity Control Level. 1 (Less) to 5 (More)", 1, 1, 5);
        this.outputSelectorOption = new ClassOption("outputSelector", 'O', "Output attributes selector", OutputAttributesSelector.class, SelectAllOutputs.class.getName());
        this.inputSelectorOption = new ClassOption("inputSelector", 'I', "Input attributes selector", InputAttributesSelector.class, SelectAllInputs.class.getName());
        this.randomSeedOption = new IntOption("randomSeedOption", 'r', "randomSeedOption", 1, Priority.ALL_INT, Integer.MAX_VALUE);
        this.featureRankingOption = new ClassOption("featureRanking", 'F', "Feature ranking algorithm.", FeatureRanking.class, NoFeatureRanking.class.getName());
        this.nAttributes = 0;
        this.randomSeedOption = this.randomSeedOption;
        this.attributesPercentage = 100.0d;
    }

    public AMRulesMultiLabelLearner(double d) {
        this();
        this.attributesPercentage = d;
    }

    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        ErrorWeightedVoteMultiLabel votes = getVotes(multiLabelInstance);
        Prediction prediction = votes.getPrediction();
        if (votes == null) {
            return null;
        }
        if (((MultiLabelLearner) getPreparedClassOption(this.learnerOption)) instanceof MultiLabelPerceptronClassification) {
            for (int i = 0; i < prediction.size(); i++) {
                prediction.setVote(i, 0, prediction.getVote(i, 0) < 0.5d ? 1.0d : 0.0d);
            }
        }
        return prediction;
    }

    public ErrorWeightedVoteMultiLabel getVotes(MultiLabelInstance multiLabelInstance) {
        ErrorWeightedVoteMultiLabel newErrorWeightedVote = newErrorWeightedVote();
        VerboseToConsole(multiLabelInstance);
        Iterator it = this.ruleSet.iterator();
        while (it.hasNext()) {
            MultiLabelRule multiLabelRule = (MultiLabelRule) it.next();
            if (multiLabelRule.isCovering(multiLabelInstance)) {
                Prediction predictionForInstance = multiLabelRule.getPredictionForInstance(multiLabelInstance);
                if (predictionForInstance != null) {
                    double[] currentErrors = multiLabelRule.getCurrentErrors();
                    if (currentErrors == null) {
                        currentErrors = defaultRuleErrors(predictionForInstance);
                    }
                    debug("Rule No" + multiLabelRule.getRuleNumberID() + " Vote: " + predictionForInstance.toString() + " Error: " + currentErrors + " Y: " + multiLabelInstance.classValue(), 3);
                    newErrorWeightedVote.addVote(predictionForInstance, currentErrors);
                }
                if (!this.unorderedRulesOption.isSet()) {
                    break;
                }
            }
        }
        if (!newErrorWeightedVote.coversAllOutputs()) {
            Prediction prediction = newErrorWeightedVote.getPrediction();
            if (prediction == null) {
                prediction = new MultiLabelPrediction(multiLabelInstance.numberOutputTargets());
            }
            Prediction predictionForInstance2 = this.defaultRule.getPredictionForInstance(multiLabelInstance);
            if (predictionForInstance2 != null) {
                double[] currentErrors2 = this.defaultRule.getCurrentErrors();
                if (currentErrors2 == null) {
                    currentErrors2 = defaultRuleErrors(predictionForInstance2);
                }
                double[] dArr = new double[prediction.numOutputAttributes()];
                MultiLabelPrediction multiLabelPrediction = new MultiLabelPrediction(prediction.numOutputAttributes());
                for (int i = 0; i < prediction.numOutputAttributes(); i++) {
                    if (!prediction.hasVotesForAttribute(i)) {
                        multiLabelPrediction.setVotes(i, predictionForInstance2.getVotes(i));
                        dArr[i] = currentErrors2[i];
                    }
                }
                newErrorWeightedVote.addVote(multiLabelPrediction, dArr);
                debug("Default Rule Vote " + predictionForInstance2.toString() + "\n Error " + currentErrors2 + "  Y: " + multiLabelInstance, 3);
            }
        }
        newErrorWeightedVote.computeWeightedVote();
        return newErrorWeightedVote;
    }

    protected double[] defaultRuleErrors(Prediction prediction) {
        double[] dArr = new double[prediction.numOutputAttributes()];
        for (int i = 0; i < prediction.numOutputAttributes(); i++) {
            if (prediction.hasVotesForAttribute(i)) {
                dArr[i] = Double.MAX_VALUE;
            }
        }
        return dArr;
    }

    @Override // moa.learners.Learner
    public boolean isRandomizable() {
        return true;
    }

    @Override // moa.classifiers.AbstractMultiLabelLearner, moa.classifiers.MultiLabelLearner
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        if (this.nAttributes == 0) {
            this.nAttributes = multiLabelInstance.numInputAttributes();
        }
        this.numInstances += multiLabelInstance.weight();
        debug("Train", 3);
        debug("Nº instance " + this.numInstances + " - " + multiLabelInstance.toString(), 3);
        boolean z = false;
        ListIterator listIterator = this.ruleSet.listIterator();
        while (listIterator.hasNext()) {
            MultiLabelRule multiLabelRule = (MultiLabelRule) listIterator.next();
            if (multiLabelRule.isCovering(multiLabelInstance)) {
                z = true;
                if (multiLabelRule.updateAnomalyDetection(multiLabelInstance)) {
                    debug("Anomaly Detected: " + this.numInstances + " Rule: " + multiLabelRule.getRuleNumberID(), 1);
                    this.numAnomaliesDetected += multiLabelInstance.weight();
                } else if (multiLabelRule.updateChangeDetection(multiLabelInstance)) {
                    debug("I) Drift Detected. Exa. : " + this.numInstances + " (" + multiLabelRule.getWeightSeenSinceExpansion() + ") Remove Rule: " + multiLabelRule.getRuleNumberID(), 1);
                    listIterator.remove();
                    multiLabelRule.notifyAll(new ChangeDetectedMessage());
                    this.numChangesDetected += multiLabelInstance.weight();
                } else {
                    multiLabelRule.trainOnInstance(multiLabelInstance);
                    if (multiLabelRule.getWeightSeenSinceExpansion() % this.gracePeriodOption.getValue() == 0.0d && multiLabelRule.tryToExpand(this.splitConfidenceOption.getValue(), this.tieThresholdOption.getValue())) {
                        MultiLabelRule newRuleFromOtherOutputs = multiLabelRule.getNewRuleFromOtherOutputs();
                        if (!this.dropOldRuleAfterExpansionOption.isSet() && multiLabelRule.hasNewRuleFromOtherOutputs()) {
                            multiLabelRule.clearOtherOutputs();
                            int i = this.ruleNumberID + 1;
                            this.ruleNumberID = i;
                            newRuleFromOtherOutputs.setRuleNumberID(i);
                            setRuleOptions(newRuleFromOtherOutputs);
                            listIterator.add(newRuleFromOtherOutputs);
                            if (this.observer != null) {
                                newRuleFromOtherOutputs.addObserver(this.observer);
                            }
                        }
                        setRuleOptions(multiLabelRule);
                        debug("Rule Expanded:", 2);
                        debug(multiLabelRule.toString(), 2);
                    }
                }
                if (!this.unorderedRulesOption.isSet()) {
                    break;
                }
            }
        }
        if (z) {
            return;
        }
        this.defaultRule.trainOnInstance(multiLabelInstance);
        if (this.defaultRule.getWeightSeenSinceExpansion() % this.gracePeriodOption.getValue() == 0.0d) {
            debug("Nr. examples " + this.defaultRule.getWeightSeenSinceExpansion(), 4);
            if (this.defaultRule.tryToExpand(this.splitConfidenceOption.getValue(), this.tieThresholdOption.getValue())) {
                MultiLabelRule newRuleFromOtherBranch = this.defaultRule.getNewRuleFromOtherBranch();
                int i2 = this.ruleNumberID + 1;
                this.ruleNumberID = i2;
                newRuleFromOtherBranch.setRuleNumberID(i2);
                setRuleOptions(newRuleFromOtherBranch);
                setRuleOptions(this.defaultRule);
                this.ruleSet.add(this.defaultRule);
                debug("Default rule expanded! New Rule:", 2);
                debug(this.defaultRule.toString(), 2);
                debug("New default rule:", 3);
                debug(newRuleFromOtherBranch.toString(), 3);
                this.defaultRule = newRuleFromOtherBranch;
                if (this.observer != null) {
                    this.defaultRule.addObserver(this.observer);
                }
            }
        }
    }

    @Override // moa.classifiers.AbstractClassifier
    protected Measurement[] getModelMeasurementsImpl() {
        Measurement[] measurementArr;
        Measurement[] measurementArr2 = {new Measurement("anomaly detections", this.numAnomaliesDetected), new Measurement("change detections", this.numChangesDetected), new Measurement("rules (number)", this.ruleSet.size() + 1), new Measurement("Avg #inputs/rule", getAverageInputs()), new Measurement("Avg #outputs/rule", getAverageOutputs())};
        if (this.featureRanking instanceof NoFeatureRanking) {
            measurementArr = measurementArr2;
        } else {
            measurementArr = new Measurement[measurementArr2.length + this.nAttributes];
            for (int i = 0; i < measurementArr2.length; i++) {
                measurementArr[i] = measurementArr2[i];
            }
            DoubleVector featureRankings = this.featureRanking.getFeatureRankings();
            for (int i2 = 0; i2 < this.nAttributes; i2++) {
                measurementArr[i2 + measurementArr2.length] = new Measurement("Attribute" + i2, featureRankings.getValue(i2));
            }
        }
        return measurementArr;
    }

    protected double getAverageInputs() {
        double d = 0.0d;
        int i = 0;
        if (this.ruleSet.size() > 0) {
            Iterator it = this.ruleSet.iterator();
            while (it.hasNext()) {
                if (((MultiLabelRule) it.next()).getInputsCovered() != null) {
                    d += r0.length;
                    i++;
                }
            }
        }
        if (this.defaultRule.getInputsCovered() != null) {
            d += r0.length;
            i++;
        }
        if (i > 0) {
            d /= i;
        }
        return d;
    }

    protected double getAverageOutputs() {
        double d = 0.0d;
        int i = 0;
        if (this.ruleSet.size() > 0) {
            Iterator it = this.ruleSet.iterator();
            while (it.hasNext()) {
                if (((MultiLabelRule) it.next()).getOutputsCovered() != null) {
                    d += r0.length;
                    i++;
                }
            }
        }
        if (this.defaultRule.getOutputsCovered() != null) {
            d += r0.length;
            i++;
        }
        if (i > 0) {
            d /= i;
        }
        return d;
    }

    @Override // moa.classifiers.AbstractClassifier
    public void getModelDescription(StringBuilder sb, int i) {
        if (this.unorderedRulesOption.isSet()) {
            StringUtils.appendIndented(sb, i, "Method Unordered");
            StringUtils.appendNewline(sb);
        } else {
            StringUtils.appendIndented(sb, i, "Method Ordered");
            StringUtils.appendNewline(sb);
        }
        StringUtils.appendIndented(sb, i, "Number of Rules: " + (this.ruleSet.size() + 1));
        StringUtils.appendNewline(sb);
        StringUtils.appendIndented(sb, i, "Default rule :");
        this.defaultRule.getDescription(sb, i);
        StringUtils.appendNewline(sb);
        StringUtils.appendIndented(sb, i, "Rules in ruleSet:");
        StringUtils.appendNewline(sb);
        Iterator it = this.ruleSet.iterator();
        while (it.hasNext()) {
            ((MultiLabelRule) it.next()).getDescription(sb, i);
            StringUtils.appendNewline(sb);
        }
    }

    protected void debug(String str, int i) {
        if (this.VerbosityOption.getValue() >= i) {
            System.out.println(str);
        }
    }

    protected void VerboseToConsole(MultiLabelInstance multiLabelInstance) {
        if (this.VerbosityOption.getValue() >= 5) {
            System.out.println();
            System.out.println("I) Dataset: " + multiLabelInstance.dataset().getRelationName());
            if (this.unorderedRulesOption.isSet()) {
                System.out.println("I) Method Unordered");
            } else {
                System.out.println("I) Method Ordered");
            }
        }
    }

    public void PrintRuleSet() {
        debug("Default rule :", 2);
        debug(this.defaultRule.toString(), 2);
        debug("Rules in ruleSet:", 2);
        Iterator it = this.ruleSet.iterator();
        while (it.hasNext()) {
            debug(((MultiLabelRule) it.next()).toString(), 2);
        }
    }

    @Override // moa.classifiers.AbstractClassifier
    public void resetLearningImpl() {
        this.defaultRule = newDefaultRule();
        this.classifierRandom.setSeed(this.randomSeed);
        MultiLabelLearner multiLabelLearner = (MultiLabelLearner) ((MultiLabelLearner) getPreparedClassOption(this.learnerOption)).copy();
        multiLabelLearner.setRandomSeed(this.randomSeed);
        multiLabelLearner.resetLearning();
        this.defaultRule.setLearner(multiLabelLearner);
        this.defaultRule.setInstanceTransformer(new NoInstanceTransformation());
        setRuleOptions(this.defaultRule);
        this.ruleSet = new MultiLabelRuleSet();
        this.ruleNumberID = 1;
        this.statistics = null;
        this.featureRanking = (FeatureRanking) getPreparedClassOption(this.featureRankingOption);
        setObserver(this.featureRanking);
    }

    protected void setRuleOptions(MultiLabelRule multiLabelRule) {
        multiLabelRule.setSplitCriterion((MultiLabelSplitCriterion) ((MultiLabelSplitCriterion) getPreparedClassOption(this.splitCriterionOption)).copy());
        multiLabelRule.setChangeDetector(((ChangeDetector) getPreparedClassOption(this.changeDetector)).copy());
        multiLabelRule.setAnomalyDetector(((AnomalyDetector) getPreparedClassOption(this.anomalyDetector)).copy());
        multiLabelRule.setNumericObserverOption((NumericStatisticsObserver) ((NumericStatisticsObserver) getPreparedClassOption(this.numericObserverOption)).copy());
        multiLabelRule.setNominalObserverOption((NominalStatisticsObserver) ((NominalStatisticsObserver) getPreparedClassOption(this.nominalObserverOption)).copy());
        multiLabelRule.setErrorMeasurer((MultiLabelErrorMeasurer) ((MultiLabelErrorMeasurer) getPreparedClassOption(this.errorMeasurerOption)).copy());
        multiLabelRule.setOutputAttributesSelector((OutputAttributesSelector) ((OutputAttributesSelector) getPreparedClassOption(this.outputSelectorOption)).copy());
        multiLabelRule.setRandomGenerator(this.classifierRandom);
        multiLabelRule.setAttributesPercentage(this.attributesPercentage);
        multiLabelRule.setInputAttributesSelector((InputAttributesSelector) ((InputAttributesSelector) getPreparedClassOption(this.inputSelectorOption)).copy());
    }

    protected abstract MultiLabelRule newDefaultRule();

    public ErrorWeightedVoteMultiLabel newErrorWeightedVote() {
        return (ErrorWeightedVoteMultiLabel) ((ErrorWeightedVoteMultiLabel) getPreparedClassOption(this.weightedVoteOption)).copy();
    }

    @Override // moa.classifiers.AbstractClassifier, moa.learners.Learner
    public void setRandomSeed(int i) {
        super.setRandomSeed(i);
        this.classifierRandom.setSeed(i);
    }

    public void setObserver(ObserverMOAObject observerMOAObject) {
        this.observer = observerMOAObject;
        this.defaultRule.addObserver(observerMOAObject);
    }
}
