package moa.tasks;

import com.github.javacliparser.IntOption;
import moa.classifiers.Classifier;
import moa.classifiers.MultiLabelClassifier;
import moa.classifiers.MultiTargetRegressor;
import moa.classifiers.rules.multilabel.functions.MultiLabelNaiveBayes;
import moa.core.ObjectRepository;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.streams.ExampleStream;
import moa.streams.InstanceStream;
import moa.streams.MultiTargetInstanceStream;
import org.apache.commons.math3.distribution.PoissonDistribution;

/* loaded from: input_file:lib/moa.jar:moa/tasks/LearnModelMultiLabel.class */
/**
 * Task that trains a multi-label learner on a multi-target stream and returns the trained learner.
 *
 * <p>The stream is traversed {@link #numPassesOption} times. The stream is restarted before every
 * pass after the first. Each pass stops when the stream has no more instances or after
 * {@link #maxInstancesOption} instances, whichever comes first.
 */
public class LearnModelMultiLabel extends MultiLabelMainTask {
    private static final long serialVersionUID = 1;

    /** Number of trained instances between abort checks and progress/preview updates. */
    private static final int INSTANCES_BETWEEN_PROGRESS_UPDATES = 10;

    /** Learner to train; must be a {@link MultiLabelClassifier} (defaults to {@link MultiLabelNaiveBayes}). */
    public ClassOption learnerOption = new ClassOption("learner", 'l', "Learner to train.", MultiLabelClassifier.class, MultiLabelNaiveBayes.class.getName());

    /** Stream to learn from; must be a {@link MultiTargetInstanceStream}. */
    public ClassOption streamOption = new ClassOption("stream", 's', "Stream to learn from.", MultiTargetInstanceStream.class, "MultiTargetArffFileStream");

    /** Maximum number of instances trained on per pass (default 10,000,000). */
    public IntOption maxInstancesOption = new IntOption("maxInstances", 'm', "Maximum number of instances to train on per pass over the data.", 10000000, 0, Integer.MAX_VALUE);

    /** Number of passes over the stream (at least 1). */
    public IntOption numPassesOption = new IntOption("numPasses", 'p', "The number of passes to do over the data.", 1, 1, Integer.MAX_VALUE);

    /** Instances between memory-bound checks. NOTE(review): not read by doMainTask in this class. */
    public IntOption memCheckFrequencyOption = new IntOption("memCheckFrequency", 'q', "How many instances between memory bound checks.", 100000, 0, Integer.MAX_VALUE);

    @Override // moa.options.AbstractOptionHandler, moa.options.OptionHandler
    public String getPurposeString() {
        return "Learns a model from a stream.";
    }

    /** Creates a task configured entirely through its options. */
    public LearnModelMultiLabel() {
    }

    /**
     * Creates a task with the given learner, stream and limits.
     *
     * @param learner learner to train; set as the current object of {@link #learnerOption}
     * @param stream stream to learn from; set as the current object of {@link #streamOption}
     * @param maxInstances maximum number of instances to train on per pass
     * @param numPasses number of passes over the stream
     */
    public LearnModelMultiLabel(Classifier learner, InstanceStream stream, int maxInstances, int numPasses) {
        this.learnerOption.setCurrentObject(learner);
        this.streamOption.setCurrentObject(stream);
        this.maxInstancesOption.setValue(maxInstances);
        this.numPassesOption.setValue(numPasses);
    }

    /**
     * Returns the declared result type of this task.
     *
     * <p>NOTE(review): {@link #learnerOption} requires a {@link MultiLabelClassifier}, but the
     * declared result type is {@link MultiTargetRegressor}. Confirm this is intended before changing.
     */
    @Override // moa.tasks.Task
    public Class<?> getTaskResultType() {
        return MultiTargetRegressor.class;
    }

    /**
     * Trains the configured learner on the configured stream.
     *
     * <p>Every {@link #INSTANCES_BETWEEN_PROGRESS_UPDATES} instances, the method checks for an abort
     * request and updates the progress fraction. If a preview was requested, it also publishes a
     * copy of the learner.
     *
     * @param taskMonitor monitor used for progress reporting, abort checks and previews
     * @param objectRepository repository passed by the task framework (not used here)
     * @return the trained learner, or {@code null} if the task was aborted
     */
    @Override // moa.tasks.MainTask
    public Object doMainTask(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
        Learner learner = (Learner) getPreparedClassOption(this.learnerOption);
        ExampleStream stream = (ExampleStream) getPreparedClassOption(this.streamOption);
        learner.setModelContext(stream.getHeader());
        int numPasses = this.numPassesOption.getValue();
        int maxInstances = this.maxInstancesOption.getValue();
        for (int pass = 0; pass < numPasses; pass++) {
            String passLabel = numPasses > 1 ? " (pass " + (pass + 1) + "/" + numPasses + ")" : "";
            taskMonitor.setCurrentActivity("Training learner" + passLabel + "...", -1.0d);
            if (pass > 0) {
                stream.restart();
            }
            long processed = 0;
            while (stream.hasMoreInstances() && (maxInstances < 0 || processed < maxInstances)) {
                learner.trainOnInstance(stream.nextInstance2());
                processed++;
                if (processed % INSTANCES_BETWEEN_PROGRESS_UPDATES == 0) {
                    if (taskMonitor.taskShouldAbort()) {
                        return null;
                    }
                    taskMonitor.setCurrentActivityFractionComplete(
                            fractionComplete(processed, stream.estimatedRemainingInstances(), maxInstances));
                    if (taskMonitor.resultPreviewRequested()) {
                        taskMonitor.setLatestResultPreview(learner.copy());
                    }
                }
            }
        }
        learner.setModelContext(stream.getHeader());
        return learner;
    }

    /**
     * Estimates how much of the current pass is complete.
     *
     * @param processed instances trained so far in this pass (positive when called)
     * @param estimatedRemaining the stream's estimate of remaining instances; negative if unknown
     * @param maxInstances per-pass instance cap; only applied when positive
     * @return the completed fraction, or {@code -1.0} when no remaining-count estimate exists
     */
    private static double fractionComplete(long processed, long estimatedRemaining, int maxInstances) {
        long remaining = estimatedRemaining;
        if (maxInstances > 0) {
            long remainingUnderCap = maxInstances - processed;
            if (remaining < 0 || remainingUnderCap < remaining) {
                remaining = remainingUnderCap;
            }
        }
        if (remaining < 0) {
            return -1.0d;
        }
        // Divide in floating point: long division would truncate to 0 until the pass finishes.
        return (double) processed / (double) (processed + remaining);
    }
}
