package moa.classifiers.rules.multilabel.core;

import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceInformation;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.Arrays;
import java.util.LinkedList;
import moa.classifiers.MultiLabelLearner;
import moa.classifiers.rules.core.NumericRulePredicate;
import moa.classifiers.rules.core.Utils;
import moa.classifiers.rules.multilabel.attributeclassobservers.AttributeStatisticsObserver;
import moa.classifiers.rules.multilabel.attributeclassobservers.NominalStatisticsObserver;
import moa.classifiers.rules.multilabel.attributeclassobservers.NumericStatisticsObserver;
import moa.classifiers.rules.multilabel.core.splitcriteria.MultiLabelSplitCriterion;
import moa.classifiers.rules.multilabel.functions.AMRulesFunction;
import moa.classifiers.rules.multilabel.instancetransformers.InstanceOutputAttributesSelector;
import moa.core.AutoExpandVector;
import moa.core.DoubleVector;
import moa.core.ObjectRepository;
import moa.tasks.TaskMonitor;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/rules/multilabel/core/LearningLiteralClassification.class */
/**
 * {@link LearningLiteral} implementation that keeps, for every output it learns, a
 * five-slot statistics vector and uses it to evaluate and perform rule expansions.
 *
 * <p>The slots of each statistics vector, as filled by {@link #trainOnInstance}, are:
 * <ol start="0">
 *   <li>sum of instance weights,</li>
 *   <li>weighted sum of the target value,</li>
 *   <li>weighted sum of the squared target value,</li>
 *   <li>weighted sum of the target value shifted by {@link #EntropyShift},</li>
 *   <li>weighted sum of the squared shifted target value.</li>
 * </ol>
 */
public class LearningLiteralClassification extends LearningLiteral {
    private static final long serialVersionUID = 1;

    /** Standard deviations not strictly above this value are treated as zero when normalizing. */
    private static final double MIN_STANDARD_DEVIATION = 1.0E-7d;

    /** Number of slots in each per-output statistics vector. */
    private static final int STATISTICS_SIZE = 5;

    /**
     * Per-output shift: the target value of the first instance seen by
     * {@link #trainOnInstance}, subtracted from later targets for statistics slots 3 and 4.
     */
    double[] EntropyShift;

    /** Creates a literal whose outputs to learn are decided on the first training instance. */
    public LearningLiteralClassification() {
    }

    /**
     * Creates a literal for the given output indices.
     *
     * @param iArr output indices, forwarded unchanged to the superclass constructor
     */
    public LearningLiteralClassification(int[] iArr) {
        super(iArr);
    }

    /**
     * Computes, for each learned output, the absolute difference between the normalized
     * predicted value and the normalized true value.
     *
     * @param prediction prediction whose vote 0 is read for every learned output
     * @param instance   instance providing the true output values
     * @return one absolute normalized error per entry of {@code outputsToLearn}
     */
    @Override // moa.classifiers.rules.multilabel.core.LearningLiteral
    protected double[] getNormalizedErrors(Prediction prediction, Instance instance) {
        double[] errors = new double[this.outputsToLearn.length];
        for (int i = 0; i < errors.length; i++) {
            int outputIndex = this.outputsToLearn[i];
            double predicted = normalizeOutputValue(i, prediction.getVote(outputIndex, 0));
            double actual = normalizeOutputValue(i, instance.valueOutputAttribute(outputIndex));
            errors[i] = Math.abs(predicted - actual);
        }
        return errors;
    }

    /**
     * Standardizes a value using the mean and standard deviation derived from the
     * statistics of the output at the given position.
     *
     * @param i position in {@code literalStatistics} (not the output attribute index)
     * @param d value to normalize
     * @return {@code (d - mean) / sd}, or {@code 0} when the standard deviation is not
     *         strictly greater than {@link #MIN_STANDARD_DEVIATION} (this includes NaN)
     */
    private double normalizeOutputValue(int i, double d) {
        DoubleVector statistics = this.literalStatistics[i];
        double weightSum = statistics.getValue(0);
        double valueSum = statistics.getValue(1);
        double standardDeviation = Utils.computeSD(statistics.getValue(2), valueSum, weightSum);
        // Written as a positive test so that a NaN deviation falls through to 0.
        if (standardDeviation > MIN_STANDARD_DEVIATION) {
            return (d - valueSum / weightSum) / standardDeviation;
        }
        return 0.0d;
    }

    /** No preparation is required for this literal. */
    @Override // moa.options.AbstractOptionHandler
    protected void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
    }

    /**
     * Evaluates the best split suggestion of every observed input and, when the best one
     * is sufficiently better than the runner-up, prepares the literals resulting from the
     * expansion.
     *
     * <p>Expansion is accepted when there is a single suggestion, when the merit gap
     * between the two best suggestions exceeds the Hoeffding bound, or when the bound is
     * smaller than {@code d2}. Nothing is expanded (and {@code false} is returned) when no
     * suggestion is available or when the winning suggestion carries no resulting
     * statistics; both situations previously ended in an exception.
     *
     * @param d  value forwarded to {@code computeHoeffdingBound} (presumably the split
     *           confidence; confirm against the superclass)
     * @param d2 threshold on the Hoeffding bound below which expansion is forced
     * @return {@code true} if the expansion literals were created
     */
    @Override // moa.classifiers.rules.multilabel.core.LearningLiteral
    public boolean tryToExpand(double d, double d2) {
        AttributeExpansionSuggestion[] suggestions = getBestSplitSuggestions(this.splitCriterion);
        if (suggestions.length == 0) {
            // No candidate at all: indexing the last element would throw.
            this.meritPerInput = null;
            return false;
        }

        recordMeritPerInput(suggestions);
        Arrays.sort(suggestions);
        selectNextInputs(suggestions);

        AttributeExpansionSuggestion best = suggestions[suggestions.length - 1];
        boolean expand;
        if (suggestions.length < 2) {
            expand = true;
        } else {
            double bound = computeHoeffdingBound(
                    this.splitCriterion.getRangeOfMerit(this.literalStatistics), d, this.weightSeen);
            double meritGap = best.merit - suggestions[suggestions.length - 2].merit;
            expand = meritGap > bound || bound < d2;
        }
        this.bestSuggestion = best;

        if (!expand) {
            return false;
        }
        DoubleVector[][] resultingStatistics = this.bestSuggestion.getResultingNodeStatistics();
        if (resultingStatistics == null) {
            // Placeholder suggestion (observer produced no split): nothing to expand on.
            return false;
        }
        buildExpansionLiterals(resultingStatistics);
        return true;
    }

    /**
     * Stores the positive merit of each suggestion in {@code meritPerInput}, indexed by
     * the attribute of the suggestion's predicate; sets it to {@code null} when no
     * suggestion has positive merit.
     */
    private void recordMeritPerInput(AttributeExpansionSuggestion[] suggestions) {
        double[] merits = new double[this.attributesMask.length];
        double totalMerit = 0.0d;
        for (AttributeExpansionSuggestion suggestion : suggestions) {
            double merit = suggestion.getMerit();
            if (merit > 0.0d) {
                merits[suggestion.predicate.getAttributeIndex()] = merit;
                totalMerit += merit;
            }
        }
        this.meritPerInput = totalMerit == 0.0d ? null : merits;
    }

    /**
     * Asks the input selector for the next inputs to learn and drops the observers of
     * previously learned, masked-in inputs that are no longer selected.
     */
    private void selectNextInputs(AttributeExpansionSuggestion[] sortedSuggestions) {
        int[] previousInputs = this.inputsToLearn.clone();
        this.inputsToLearn = this.inputSelector.getNextInputIndices(sortedSuggestions);
        // Sorted so that membership can be tested with a binary search.
        Arrays.sort(this.inputsToLearn);
        for (int inputIndex : previousInputs) {
            if (this.attributesMask[inputIndex] && Arrays.binarySearch(this.inputsToLearn, inputIndex) < 0) {
                this.attributeObservers.set(inputIndex, null);
            }
        }
    }

    /**
     * Creates the literals produced by expanding on {@code bestSuggestion}: the literal for
     * the other branch, optionally the literal for the outputs that are dropped, and the
     * expanded literal for the outputs that are kept. Resets the current learner.
     */
    private void buildExpansionLiterals(DoubleVector[][] resultingStatistics) {
        double[] branchMerits = this.splitCriterion.getBranchesSplitMerits(resultingStatistics);
        int keptBranch = branchMerits[1] > branchMerits[0] ? 1 : 0;
        if (keptBranch == 1) {
            // The predicate is negated when branch 1 is the one being kept.
            this.bestSuggestion.getPredicate().negateCondition();
        }
        DoubleVector[] branchStatistics = getBranchStatistics(resultingStatistics, keptBranch);

        int[] nextOutputs = this.outputSelector.getNextOutputIndices(
                branchStatistics, this.literalStatistics, this.outputsToLearn);
        Arrays.sort(nextOutputs);

        // The other-branch literal receives a copy of the learner taken before any reset.
        this.otherBranchLearningLiteral = new LearningLiteralClassification();
        this.otherBranchLearningLiteral.instanceHeader = this.instanceHeader;
        this.otherBranchLearningLiteral.learner = (MultiLabelLearner) this.learner.copy();
        this.otherBranchLearningLiteral.instanceTransformer = this.instanceTransformer;

        boolean outputsReduced = nextOutputs.length != this.outputsToLearn.length;
        if (this.learner instanceof AMRulesFunction) {
            if (outputsReduced) {
                int[] droppedOutputs = Utils.complementSet(this.outputsToLearn, nextOutputs);
                if (droppedOutputs.length > 0) {
                    this.otherOutputsLearningLiteral = new LearningLiteralClassification(droppedOutputs);
                    MultiLabelLearner droppedLearner = (MultiLabelLearner) this.learner.copy();
                    ((AMRulesFunction) droppedLearner).selectOutputsToLearn(
                            Utils.getIndexCorrespondence(this.outputsToLearn, droppedOutputs));
                    ((AMRulesFunction) droppedLearner).resetWithMemory();
                    this.otherOutputsLearningLiteral.learner = droppedLearner;
                    this.otherOutputsLearningLiteral.instanceHeader = this.instanceHeader;
                    this.otherOutputsLearningLiteral.instanceTransformer =
                            new InstanceOutputAttributesSelector(this.instanceHeader, droppedOutputs);
                }
                ((AMRulesFunction) this.learner).selectOutputsToLearn(
                        Utils.getIndexCorrespondence(this.outputsToLearn, nextOutputs));
            }
            ((AMRulesFunction) this.learner).resetWithMemory();
        } else {
            if (outputsReduced) {
                int[] droppedOutputs = Utils.complementSet(this.outputsToLearn, nextOutputs);
                if (droppedOutputs.length > 0) {
                    // NOTE(review): unlike the AMRulesFunction branch, this literal is built
                    // without the dropped output indices - confirm this is intentional.
                    this.otherOutputsLearningLiteral = new LearningLiteralClassification();
                    MultiLabelLearner droppedLearner = (MultiLabelLearner) this.learner.copy();
                    droppedLearner.resetLearning();
                    this.otherOutputsLearningLiteral.learner = droppedLearner;
                    this.otherOutputsLearningLiteral.instanceHeader = this.instanceHeader;
                    this.otherOutputsLearningLiteral.instanceTransformer =
                            new InstanceOutputAttributesSelector(this.instanceHeader, droppedOutputs);
                }
            }
            this.learner.resetLearning();
        }

        // The expanded literal receives a copy of the learner taken after the reset.
        this.expandedLearningLiteral = new LearningLiteralClassification(nextOutputs);
        this.expandedLearningLiteral.learner = (MultiLabelLearner) this.learner.copy();
        this.expandedLearningLiteral.instanceHeader = this.instanceHeader;
        this.expandedLearningLiteral.instanceTransformer =
                new InstanceOutputAttributesSelector(this.instanceHeader, nextOutputs);
    }

    /**
     * Extracts, for every row of the given matrix, the statistics of one branch.
     *
     * @param doubleVectorArr statistics indexed as {@code [row][branch]}
     * @param i               branch index to extract
     * @return array holding {@code doubleVectorArr[row][i]} for every row
     */
    private DoubleVector[] getBranchStatistics(DoubleVector[][] doubleVectorArr, int i) {
        DoubleVector[] branch = new DoubleVector[doubleVectorArr.length];
        for (int row = 0; row < branch.length; row++) {
            branch[row] = doubleVectorArr[row][i];
        }
        return branch;
    }

    /**
     * Collects the best split suggestion of every masked-in input that has an observer.
     * When an observer returns no suggestion, a placeholder with a numeric predicate, no
     * resulting statistics and the lowest possible merit is used instead.
     *
     * @param multiLabelSplitCriterion criterion passed to the observers
     * @return one suggestion per observed input, in the order of {@code inputsToLearn}
     */
    private AttributeExpansionSuggestion[] getBestSplitSuggestions(MultiLabelSplitCriterion multiLabelSplitCriterion) {
        LinkedList<AttributeExpansionSuggestion> suggestions = new LinkedList<>();
        for (int inputIndex : this.inputsToLearn) {
            if (!this.attributesMask[inputIndex]) {
                continue;
            }
            AttributeStatisticsObserver observer = this.attributeObservers.get(inputIndex);
            if (observer == null) {
                continue;
            }
            AttributeExpansionSuggestion suggestion = observer.getBestEvaluatedSplitSuggestion(
                    multiLabelSplitCriterion, this.literalStatistics, inputIndex);
            if (suggestion == null) {
                suggestion = new AttributeExpansionSuggestion(
                        new NumericRulePredicate(inputIndex, 0.0d, true), null, -Double.MAX_VALUE);
            }
            suggestions.add(suggestion);
        }
        return suggestions.toArray(new AttributeExpansionSuggestion[suggestions.size()]);
    }

    /**
     * Updates the literal with one training instance: lazily initializes its state on the
     * first call, accumulates the per-output statistics, feeds the attribute observers,
     * records the learner's prediction error and finally trains the learner.
     *
     * @param multiLabelInstance instance to learn from
     */
    @Override // moa.classifiers.rules.multilabel.core.LearningLiteral
    public void trainOnInstance(MultiLabelInstance multiLabelInstance) {
        boolean maskCreatedHere = this.attributesMask == null;
        int selectedInputCount = maskCreatedHere ? initializeAttibutesMask(multiLabelInstance) : 0;
        if (!this.hasStarted) {
            initializeOnFirstInstance(multiLabelInstance, maskCreatedHere, selectedInputCount);
        }

        double weight = multiLabelInstance.weight();
        DoubleVector[] instanceStatistics = new DoubleVector[this.outputsToLearn.length];
        for (int i = 0; i < instanceStatistics.length; i++) {
            double target = multiLabelInstance.valueOutputAttribute(this.outputsToLearn[i]);
            double shiftedTarget = target - this.EntropyShift[i];
            // Slot 3 weights the shifted target so that it is consistent with slot 4;
            // the previous "weight * target - shift" only agreed for weight == 1.
            instanceStatistics[i] = new DoubleVector(new double[]{
                weight,
                weight * target,
                weight * target * target,
                weight * shiftedTarget,
                weight * shiftedTarget * shiftedTarget});
            this.literalStatistics[i].addValues(instanceStatistics[i].getArrayRef());
        }

        if (this.attributeObservers == null) {
            this.attributeObservers = new AutoExpandVector<>();
        }
        for (int inputIndex : this.inputsToLearn) {
            if (!this.attributesMask[inputIndex]) {
                continue;
            }
            AttributeStatisticsObserver observer = this.attributeObservers.get(inputIndex);
            if (observer == null) {
                if (multiLabelInstance.attribute(inputIndex).isNumeric()) {
                    observer = (NumericStatisticsObserver) this.numericStatisticsObserver.copy();
                } else if (multiLabelInstance.attribute(inputIndex).isNominal()) {
                    observer = (NominalStatisticsObserver) this.nominalStatisticsObserver.copy();
                } else {
                    // Neither numeric nor nominal: no observer exists for this attribute,
                    // so it is skipped instead of dereferencing null.
                    continue;
                }
                this.attributeObservers.set(inputIndex, observer);
            }
            observer.observeAttribute(multiLabelInstance.valueInputAttribute(inputIndex), instanceStatistics);
        }

        Instance targetInstance = this.instanceTransformer.sourceInstanceToTarget(multiLabelInstance);
        Prediction targetPrediction = this.learner.getPredictionForInstance(targetInstance);
        if (targetPrediction != null) {
            Prediction sourcePrediction = this.instanceTransformer.targetPredictionToSource(targetPrediction);
            if (sourcePrediction != null) {
                // The error is measured before the learner sees the instance.
                this.errorMeasurer.addPrediction(sourcePrediction, multiLabelInstance);
            }
        }
        this.learner.trainOnInstance(targetInstance);
        this.weightSeen += multiLabelInstance.weight();
    }

    /**
     * Performs the one-time setup done on the first training instance: seeds the learner,
     * defaults the outputs and inputs to learn, and allocates the statistics and shifts.
     *
     * @param instance           first training instance
     * @param maskCreatedHere    whether the attribute mask was created during this call
     * @param selectedInputCount value returned by {@code initializeAttibutesMask} when the
     *                           mask was created during this call, otherwise unused
     */
    private void initializeOnFirstInstance(MultiLabelInstance instance, boolean maskCreatedHere,
            int selectedInputCount) {
        if (this.learner.isRandomizable()) {
            this.learner.setRandomSeed(this.randomGenerator.nextInt());
        }
        if (this.outputsToLearn == null) {
            int outputCount = instance.numberOutputTargets();
            this.outputsToLearn = new int[outputCount];
            for (int i = 0; i < outputCount; i++) {
                this.outputsToLearn[i] = i;
            }
        }
        if (this.inputsToLearn == null) {
            int inputCount = instance.numInputAttributes();
            int capacity = selectedInputCount;
            if (!maskCreatedHere) {
                // The mask already existed, so no count was returned for it; count the
                // selected inputs directly instead of allocating an empty array.
                capacity = 0;
                for (int i = 0; i < inputCount; i++) {
                    if (this.attributesMask[i]) {
                        capacity++;
                    }
                }
            }
            this.inputsToLearn = new int[capacity];
            int next = 0;
            for (int i = 0; i < inputCount; i++) {
                if (this.attributesMask[i]) {
                    this.inputsToLearn[next++] = i;
                }
            }
        }
        this.literalStatistics = new DoubleVector[this.outputsToLearn.length];
        this.EntropyShift = new double[this.outputsToLearn.length];
        for (int i = 0; i < this.outputsToLearn.length; i++) {
            this.literalStatistics[i] = new DoubleVector(new double[STATISTICS_SIZE]);
            // The first observed target value becomes the shift for this output.
            this.EntropyShift[i] = instance.valueOutputAttribute(this.outputsToLearn[i]);
        }
        this.instanceHeader = (InstancesHeader) instance.dataset();
        this.hasStarted = true;
    }

    /**
     * Describes the literal as a sequence of {@code "<output name>: <mean> "} entries, one
     * per learned output, where the mean is statistics slot 1 divided by slot 0.
     *
     * @param instanceInformation source of the output attribute names
     * @return the description, or an empty string when no statistics exist yet
     */
    @Override // moa.classifiers.rules.multilabel.core.LearningLiteral
    public String getStaticOutput(InstanceInformation instanceInformation) {
        StringBuilder description = new StringBuilder();
        if (this.literalStatistics != null) {
            for (int i = 0; i < this.literalStatistics.length; i++) {
                DoubleVector statistics = this.literalStatistics[i];
                description.append(instanceInformation.outputAttribute(this.outputsToLearn[i]).name())
                        .append(": ")
                        .append(statistics.getValue(1) / statistics.getValue(0))
                        .append(" ");
            }
        }
        return description.toString();
    }
}
