package moa.classifiers.rules.multilabel.functions;

import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.Prediction;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiTargetRegressor;
import moa.classifiers.rules.multilabel.errormeasurers.AbstractMultiTargetErrorMeasurer;
import moa.classifiers.rules.multilabel.errormeasurers.MeanAbsoluteDeviationMT;
import moa.classifiers.rules.multilabel.errormeasurers.MultiTargetErrorMeasurer;
import moa.core.Measurement;
import moa.options.ClassOption;
import org.apache.log4j.Priority;

/* loaded from: input_file:lib/moa.jar:moa/classifiers/rules/multilabel/functions/AdaptiveMultiTargetRegressor.class */
/**
 * Multi-target regressor that trains two base learners side by side and, at prediction time,
 * delegates to whichever one currently reports the lower error according to the configured
 * {@link MultiTargetErrorMeasurer}.
 *
 * <p>The base learners and their error measurers are created lazily on the first training
 * instance. Each training instance is first used to score both learners (predict, then record
 * the error) and only afterwards to train them, so the tracked error reflects data the learner
 * had not yet been trained on.
 */
public class AdaptiveMultiTargetRegressor extends AbstractMultiLabelLearner
        implements MultiTargetRegressor, AMRulesFunction {
    private static final long serialVersionUID = 1;

    /** Number of competing base learners; one per {@code baseLearnerOption*} field. */
    private static final int NUM_LEARNERS = 2;

    public ClassOption baseLearnerOption1 = new ClassOption("baseLearner1", 'l', "First base learner.",
            AMRulesFunction.class, MultiTargetMeanRegressor.class.getName());

    public ClassOption baseLearnerOption2 = new ClassOption("baseLearner2", 'm', "Second base learner.",
            AMRulesFunction.class, MultiTargetPerceptronRegressor.class.getName());

    public ClassOption errorMeasurerOption = new ClassOption("errorMeasurer", 'e',
            "Measure of error for deciding which learner should predict.",
            AbstractMultiTargetErrorMeasurer.class, MeanAbsoluteDeviationMT.class.getName());

    /** Seed used to re-seed {@code classifierRandom} on reset; any int value is accepted. */
    public IntOption randomSeedOption = new IntOption("randomSeedOption", 'r', "randomSeedOption", 1,
            Integer.MIN_VALUE, Integer.MAX_VALUE);

    /** Whether {@link #baseLearner} and {@link #errorMeasurer} have been initialised for training. */
    protected boolean hasStarted;

    /** The competing base learners; {@code null} until the first training instance. */
    protected MultiTargetRegressor[] baseLearner;

    /** Running error of each base learner, index-aligned with {@link #baseLearner}. */
    protected MultiTargetErrorMeasurer[] errorMeasurer;

    /**
     * Scores every base learner on {@code multiLabelInstance}, then trains it on the instance.
     * Initialises the learners and error measurers on the first call.
     *
     * @param multiLabelInstance the training instance
     */
    @Override
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        if (!this.hasStarted) {
            initLearners();
        }
        for (int i = 0; i < NUM_LEARNERS; i++) {
            // Predict before training so the recorded error is measured on an unseen instance.
            Prediction prediction = this.baseLearner[i].getPredictionForInstance(multiLabelInstance);
            if (prediction != null) {
                this.errorMeasurer[i].addPrediction(prediction, multiLabelInstance);
            }
            // NOTE(review): this calls the base learner's trainOnInstanceImpl directly rather than a
            // public trainOnInstance wrapper; confirm any wrapper bookkeeping is meant to be skipped.
            this.baseLearner[i].trainOnInstanceImpl(multiLabelInstance);
        }
    }

    /**
     * Creates both base learners from their options. Seeds and resets each one, gives each a
     * fresh error measurer, and then marks the model as started.
     */
    private void initLearners() {
        this.baseLearner = new MultiTargetRegressor[NUM_LEARNERS];
        this.errorMeasurer = new MultiTargetErrorMeasurer[NUM_LEARNERS];
        this.baseLearner[0] = (MultiTargetRegressor) getPreparedClassOption(this.baseLearnerOption1);
        this.baseLearner[1] = (MultiTargetRegressor) getPreparedClassOption(this.baseLearnerOption2);
        for (int i = 0; i < NUM_LEARNERS; i++) {
            reseedAndReset(this.baseLearner[i]);
            this.errorMeasurer[i] = newErrorMeasurer();
        }
        this.hasStarted = true;
    }

    /** Passes this learner's {@code randomSeed} to {@code learner} if it is randomizable, then resets it. */
    private void reseedAndReset(MultiTargetRegressor learner) {
        if (learner.isRandomizable()) {
            learner.setRandomSeed(this.randomSeed);
        }
        learner.resetLearning();
    }

    /**
     * Builds an error measurer from {@link #errorMeasurerOption}. It is copied so that each
     * caller gets its own instance rather than sharing the prepared option object.
     */
    private MultiTargetErrorMeasurer newErrorMeasurer() {
        return (MultiTargetErrorMeasurer) ((MultiTargetErrorMeasurer) getPreparedClassOption(this.errorMeasurerOption)).copy();
    }

    /**
     * Predicts using the base learner with the lowest current error.
     *
     * @param multiLabelInstance the instance to predict
     * @return the chosen learner's prediction, or {@code null} if no training has happened yet
     */
    @Override
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        if (!this.hasStarted) {
            return null;
        }
        return this.baseLearner[bestLearnerIndex()].getPredictionForInstance(multiLabelInstance);
    }

    /**
     * Returns the index of the learner whose current error is strictly lowest. A learner must
     * beat every earlier learner and {@code Double.MAX_VALUE} to win. Ties, NaN errors and
     * all-{@code MAX_VALUE} errors therefore resolve to the earlier learner, which defaults to
     * index 0.
     */
    private int bestLearnerIndex() {
        int best = 0;
        double bestError = Double.MAX_VALUE;
        for (int i = 0; i < NUM_LEARNERS; i++) {
            double error = this.errorMeasurer[i].getCurrentError();
            if (error < bestError) {
                bestError = error;
                best = i;
            }
        }
        return best;
    }

    /** Always {@code true}: the seed is forwarded to randomizable base learners. */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Marks the model as not started, so the next training instance re-initialises the learners
     * and error measurers. Also re-seeds {@code classifierRandom} from {@link #randomSeedOption}
     * and resets any base learners that already exist.
     */
    @Override
    public void resetLearningImpl() {
        this.hasStarted = false;
        if (this.baseLearner != null) {
            // The seed does not depend on the learner, so it only needs to be applied once.
            this.classifierRandom.setSeed(this.randomSeedOption.getValue());
            for (MultiTargetRegressor learner : this.baseLearner) {
                reseedAndReset(learner);
            }
        }
    }

    /** No model-specific measurements are reported. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Intentionally emits nothing.
     *
     * @param sb the builder to append to (left unchanged)
     * @param i indentation level (unused)
     */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    @Override
    public String getPurposeString() {
        return "Learns two regressors and uses the regressor with less error to predict.";
    }

    /**
     * Replaces both error measurers with fresh ones, discarding the accumulated error history.
     * Base learners that implement {@link AMRulesFunction} also receive
     * {@code resetWithMemory()}. Safe to call before the first training instance; in that case
     * no base learners exist yet and only the error measurers are created.
     */
    @Override
    public void resetWithMemory() {
        if (this.errorMeasurer == null) {
            this.errorMeasurer = new MultiTargetErrorMeasurer[NUM_LEARNERS];
        }
        for (int i = 0; i < NUM_LEARNERS; i++) {
            this.errorMeasurer[i] = newErrorMeasurer();
            // baseLearner is only allocated by the first training call.
            if (this.baseLearner != null && this.baseLearner[i] instanceof AMRulesFunction) {
                ((AMRulesFunction) this.baseLearner[i]).resetWithMemory();
            }
        }
    }

    /**
     * Forwards the output selection to both base learners.
     *
     * <p>NOTE(review): this dereferences {@link #baseLearner}, so it requires at least one prior
     * training instance. It also assumes both learners implement {@link AMRulesFunction}, which
     * the option declarations require.
     *
     * @param iArr the outputs the base learners should learn
     */
    @Override
    public void selectOutputsToLearn(int[] iArr) {
        for (int i = 0; i < NUM_LEARNERS; i++) {
            ((AMRulesFunction) this.baseLearner[i]).selectOutputsToLearn(iArr);
        }
    }
}
