package meka.experiment.evaluators;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import meka.classifiers.multilabel.Evaluation;
import meka.classifiers.multilabel.MultiLabelClassifier;
import meka.core.OptionUtils;
import meka.core.ThreadLimiter;
import meka.core.ThreadUtils;
import meka.experiment.evaluationstatistics.EvaluationStatistics;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.Utils;

/* loaded from: input_file:lib/meka-1.9.7.jar:meka/experiment/evaluators/CrossValidation.class */
/**
 * Evaluates a multi-label classifier using k-fold cross-validation.
 * <p>
 * Unless order preservation is enabled, a copy of the dataset is randomized with a
 * {@link Random} seeded from {@link #getSeed()} before the folds are generated, and the same
 * generator is passed to {@code trainCV} for every training fold. Folds are evaluated either
 * sequentially or in a fixed-size thread pool, depending on {@link #getNumThreads()}. Every
 * successfully evaluated fold yields one {@link EvaluationStatistics} tagged with its 1-based
 * fold number under {@link #KEY_FOLD}. If the evaluation is stopped, no statistics are returned.
 */
public class CrossValidation extends AbstractEvaluator implements Randomizable, ThreadLimiter {
    private static final long serialVersionUID = 6318297857792961890L;

    /** Statistics key under which the 1-based fold number is stored. */
    public static final String KEY_FOLD = "Fold";

    /** Number of threads resolved for the current/last evaluation (1 = sequential). */
    protected int m_ActualNumThreads;

    /**
     * Thread pool of the current/last parallel evaluation; {@code null} until one has run.
     * Volatile so that a {@link #stop()} issued from another thread, while
     * {@link #evaluateParallel} is blocked waiting for its jobs, sees the current pool.
     */
    protected transient volatile ExecutorService m_Executor;

    /** Number of folds (the setter only accepts values of at least 2). */
    protected int m_NumFolds = getDefaultNumFolds();

    /** Whether to skip all randomization and keep the data in its original order. */
    protected boolean m_PreserveOrder = false;

    /** Seed for the random number generator used for randomization. */
    protected int m_Seed = getDefaultSeed();

    /** Requested number of threads (-1 = number of CPUs/cores; 0 or 1 = sequential). */
    protected int m_NumThreads = getDefaultNumThreads();

    /** Threshold option, passed on to {@link Evaluation#evaluateModel}. */
    protected String m_Threshold = getDefaultThreshold();

    /** Verbosity option, passed on to {@link Evaluation#evaluateModel}. */
    protected String m_Verbosity = getDefaultVerbosity();

    /**
     * Returns a short description of this evaluator.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "Evaluates the classifier using cross-validation. Order can be preserved.";
    }

    /**
     * Returns the default number of folds.
     *
     * @return 10
     */
    protected int getDefaultNumFolds() {
        return 10;
    }

    /**
     * Sets the number of folds. Values below 2 are rejected with a message on stderr and the
     * current value is kept.
     *
     * @param i the number of folds, at least 2
     */
    public void setNumFolds(int i) {
        if (i >= 2) {
            m_NumFolds = i;
        } else {
            System.err.println("Number of folds must be >= 2, provided: " + i);
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds
     */
    public int getNumFolds() {
        return m_NumFolds;
    }

    /**
     * Returns the tip text for the number of folds.
     *
     * @return the tip text
     */
    public String numFoldsTipText() {
        return "The number of folds to use.";
    }

    /**
     * Sets whether to preserve the order of the data, i.e., skip all randomization.
     *
     * @param z true to preserve the order
     */
    public void setPreserveOrder(boolean z) {
        m_PreserveOrder = z;
    }

    /**
     * Returns whether the order of the data is preserved.
     *
     * @return true if no randomization takes place
     */
    public boolean getPreserveOrder() {
        return m_PreserveOrder;
    }

    /**
     * Returns the tip text for the preserve-order flag.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "If enabled, no randomization is occurring and the order in the data is preserved.";
    }

    /**
     * Returns the default seed.
     *
     * @return 0
     */
    protected int getDefaultSeed() {
        return 0;
    }

    /**
     * Sets the seed for the random number generator.
     *
     * @param i the seed
     */
    @Override
    public void setSeed(int i) {
        m_Seed = i;
    }

    /**
     * Returns the seed for the random number generator.
     *
     * @return the seed
     */
    @Override
    public int getSeed() {
        return m_Seed;
    }

    /**
     * Returns the tip text for the seed.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed to use for randomization.";
    }

    /**
     * Returns the default number of threads.
     *
     * @return -1 (number of CPUs/cores)
     */
    protected int getDefaultNumThreads() {
        return -1;
    }

    /**
     * Sets the number of threads. Values below -1 are rejected via {@code log(...)} and the
     * current value is kept.
     *
     * @param i -1 = number of CPUs/cores; 0 or 1 = sequential execution; otherwise thread count
     */
    @Override
    public void setNumThreads(int i) {
        if (i >= -1) {
            m_NumThreads = i;
        } else {
            log("Number of threads must be >= -1, provided: " + i);
        }
    }

    /**
     * Returns the requested number of threads.
     *
     * @return the number of threads
     */
    @Override
    public int getNumThreads() {
        return m_NumThreads;
    }

    /**
     * Returns the tip text for the number of threads.
     *
     * @return the tip text
     */
    public String numThreadsTipText() {
        return "The number of threads to use ; -1 = number of CPUs/cores; 0 or 1 = sequential execution.";
    }

    /**
     * Returns the default threshold option.
     *
     * @return "PCut1"
     */
    protected String getDefaultThreshold() {
        return "PCut1";
    }

    /**
     * Sets the threshold option passed to the evaluation.
     *
     * @param str the threshold option
     */
    public void setThreshold(String str) {
        m_Threshold = str;
    }

    /**
     * Returns the threshold option passed to the evaluation.
     *
     * @return the threshold option
     */
    public String getThreshold() {
        return m_Threshold;
    }

    /**
     * Returns the tip text for the threshold option.
     *
     * @return the tip text
     */
    public String thresholdTipText() {
        return "The threshold option.";
    }

    /**
     * Returns the default verbosity option.
     *
     * @return "3"
     */
    protected String getDefaultVerbosity() {
        return "3";
    }

    /**
     * Sets the verbosity option passed to the evaluation.
     *
     * @param str the verbosity option
     */
    public void setVerbosity(String str) {
        m_Verbosity = str;
    }

    /**
     * Returns the verbosity option passed to the evaluation.
     *
     * @return the verbosity option
     */
    public String getVerbosity() {
        return m_Verbosity;
    }

    /**
     * Returns the tip text for the verbosity option.
     *
     * @return the tip text
     */
    public String verbosityTipText() {
        return "The verbosity option.";
    }

    /**
     * Lists the options of the superclass followed by -F, -O, -S, -T, -V and -num-threads.
     *
     * @return the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<>();
        OptionUtils.add(result, super.listOptions());
        OptionUtils.addOption(result, numFoldsTipText(), getDefaultNumFolds(), 'F');
        OptionUtils.addFlag(result, preserveOrderTipText(), 'O');
        OptionUtils.addOption(result, seedTipText(), getDefaultSeed(), 'S');
        OptionUtils.addOption(result, thresholdTipText(), getDefaultThreshold(), 'T');
        OptionUtils.addOption(result, verbosityTipText(), getDefaultVerbosity(), 'V');
        OptionUtils.addOption(result, numThreadsTipText(), getDefaultNumThreads(), "num-threads");
        return OptionUtils.toEnumeration(result);
    }

    /**
     * Parses this evaluator's options, falling back to the defaults for missing ones, and then
     * passes the remaining options to the superclass.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        setNumFolds(OptionUtils.parse(strArr, 'F', getDefaultNumFolds()));
        setPreserveOrder(Utils.getFlag('O', strArr));
        setSeed(OptionUtils.parse(strArr, 'S', getDefaultSeed()));
        setThreshold(OptionUtils.parse(strArr, 'T', getDefaultThreshold()));
        setVerbosity(OptionUtils.parse(strArr, 'V', getDefaultVerbosity()));
        setNumThreads(OptionUtils.parse(strArr, "num-threads", getDefaultNumThreads()));
        super.setOptions(strArr);
    }

    /**
     * Returns the current options: those of the superclass followed by this evaluator's.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        List<String> result = new ArrayList<>();
        OptionUtils.add(result, super.getOptions());
        OptionUtils.add(result, 'F', getNumFolds());
        OptionUtils.add(result, 'O', getPreserveOrder());
        OptionUtils.add(result, 'S', getSeed());
        OptionUtils.add(result, 'T', getThreshold());
        OptionUtils.add(result, 'V', getVerbosity());
        OptionUtils.add(result, "num-threads", getNumThreads());
        return OptionUtils.toArray(result);
    }

    /**
     * Returns a copy of the data, randomized with the given generator unless order is preserved.
     */
    private Instances prepareData(Instances data, Random random) {
        Instances result = new Instances(data);
        if (!m_PreserveOrder) {
            result.randomize(random);
        }
        return result;
    }

    /**
     * Returns the training set of the given 1-based fold, using the Random-taking trainCV
     * variant unless order is preserved.
     */
    private Instances trainFold(Instances data, int fold, Random random) {
        if (m_PreserveOrder) {
            return data.trainCV(m_NumFolds, fold - 1);
        }
        return data.trainCV(m_NumFolds, fold - 1, random);
    }

    /**
     * Evaluates the folds one after another in the calling thread. A failing fold is reported
     * via {@code handleException} and skipped.
     *
     * @param multiLabelClassifier the classifier template; each fold evaluates a shallow copy
     * @param instances the dataset (not modified; a copy is used)
     * @return the statistics per successful fold, or an empty list if stopped
     */
    protected List<EvaluationStatistics> evaluateSequential(MultiLabelClassifier multiLabelClassifier, Instances instances) {
        List<EvaluationStatistics> result = new ArrayList<>();
        Random random = new Random(m_Seed);
        Instances data = prepareData(instances, random);
        for (int fold = 1; fold <= m_NumFolds; fold++) {
            // Checked before every fold (not only after a successful one) so that a stop request
            // is honoured even when the previous fold failed.
            if (m_Stopped) {
                break;
            }
            log("Fold: " + fold);
            try {
                // Same order as before: copy the classifier, then build train/test, so the random
                // generator is consumed identically.
                MultiLabelClassifier current = (MultiLabelClassifier) OptionUtils.shallowCopy(multiLabelClassifier);
                Instances train = trainFold(data, fold, random);
                Instances test = data.testCV(m_NumFolds, fold - 1);
                EvaluationStatistics stats = new EvaluationStatistics(multiLabelClassifier, data,
                        Evaluation.evaluateModel(current, train, test, m_Threshold, m_Verbosity));
                stats.put(KEY_FOLD, Integer.valueOf(fold));
                result.add(stats);
            } catch (Exception e) {
                handleException("Failed to evaluate dataset '" + instances.relationName() + "' with classifier: " + Utils.toCommandLine(multiLabelClassifier), e);
            }
        }
        if (m_Stopped) {
            result.clear();
        }
        return result;
    }

    /**
     * Evaluates the folds in a fixed-size pool of {@link #m_ActualNumThreads} threads. All folds
     * are prepared up front in the calling thread (so the random generator is consumed in the
     * same order as in {@link #evaluateSequential}); then the jobs are submitted and awaited.
     *
     * @param multiLabelClassifier the classifier template; each fold evaluates a shallow copy
     * @param instances the dataset (not modified; a copy is used)
     * @return the statistics per successful fold, or an empty list if stopped
     */
    protected List<EvaluationStatistics> evaluateParallel(final MultiLabelClassifier multiLabelClassifier, final Instances instances) {
        List<EvaluationStatistics> result = new ArrayList<>();

        debug("pre: create jobs");
        List<EvaluatorJob> jobs = new ArrayList<>();
        Random random = new Random(m_Seed);
        Instances data = prepareData(instances, random);
        for (int i = 1; i <= m_NumFolds; i++) {
            final int fold = i;
            final Instances train = trainFold(data, fold, random);
            final Instances test = data.testCV(m_NumFolds, fold - 1);
            final MultiLabelClassifier current = (MultiLabelClassifier) OptionUtils.shallowCopy(multiLabelClassifier);
            jobs.add(new EvaluatorJob() {
                @Override
                protected List<EvaluationStatistics> doCall() throws Exception {
                    List<EvaluationStatistics> foldResult = new ArrayList<>();
                    CrossValidation.this.log("Executing fold #" + fold + "...");
                    try {
                        EvaluationStatistics stats = new EvaluationStatistics(multiLabelClassifier, instances,
                                Evaluation.evaluateModel(current, train, test, CrossValidation.this.m_Threshold, CrossValidation.this.m_Verbosity));
                        stats.put(CrossValidation.KEY_FOLD, Integer.valueOf(fold));
                        foldResult.add(stats);
                    } catch (Exception e) {
                        CrossValidation.this.handleException("Failed to evaluate dataset '" + instances.relationName() + "' with classifier: " + Utils.toCommandLine(multiLabelClassifier), e);
                    }
                    CrossValidation.this.log("...finished fold #" + fold);
                    return foldResult;
                }
            });
        }
        debug("post: create jobs");

        ExecutorService executor = Executors.newFixedThreadPool(m_ActualNumThreads);
        m_Executor = executor;

        debug("pre: submit");
        for (EvaluatorJob job : jobs) {
            // stop() may have run before m_Executor was assigned; don't start further work then.
            if (m_Stopped) {
                break;
            }
            try {
                executor.submit((Callable<?>) job);
            } catch (RejectedExecutionException e) {
                // The pool was shut down by stop(); the remaining jobs are intentionally skipped.
            } catch (Exception e) {
                handleException("Failed to start up jobs", e);
            }
        }
        debug("post: submit");

        debug("pre: shutdown");
        executor.shutdown();
        debug("post: shutdown");

        debug("pre: wait");
        boolean interrupted = false;
        while (!executor.isTerminated()) {
            try {
                executor.awaitTermination(100L, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                // Keep waiting so that finished folds can still be collected, but remember the
                // interrupt so it can be restored for the caller afterwards.
                interrupted = true;
            } catch (Exception e) {
                handleException("Failed to await termination", e);
            }
        }
        debug("post: wait");

        debug("pre: collect");
        for (EvaluatorJob job : jobs) {
            // A job that never ran (not submitted, or drained by shutdownNow() in stop()) may
            // have no result set.
            if (job.getResult() != null) {
                result.addAll(job.getResult());
            }
        }
        debug("post: collect");

        if (m_Stopped) {
            result.clear();
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        return result;
    }

    /**
     * Cross-validates the classifier on the dataset. The number of threads is resolved via
     * {@code ThreadUtils.getActualNumThreads(numThreads, numFolds)}; a single thread selects
     * {@link #evaluateSequential}, anything more {@link #evaluateParallel}.
     *
     * @param multiLabelClassifier the classifier template to evaluate
     * @param instances the dataset to evaluate on
     * @return the statistics per successful fold, or an empty list if stopped
     */
    @Override
    public List<EvaluationStatistics> evaluate(MultiLabelClassifier multiLabelClassifier, Instances instances) {
        m_ActualNumThreads = ThreadUtils.getActualNumThreads(m_NumThreads, m_NumFolds);
        log("Number of threads (1 = sequential): " + m_ActualNumThreads);
        List<EvaluationStatistics> result;
        // "<= 1" rather than "== 1": a fixed thread pool cannot be created with fewer than one
        // thread, so any non-positive count is treated as sequential as well.
        if (m_ActualNumThreads <= 1) {
            result = evaluateSequential(multiLabelClassifier, instances);
        } else {
            result = evaluateParallel(multiLabelClassifier, instances);
        }
        if (m_Stopped) {
            result.clear();
        }
        return result;
    }

    /**
     * Stops the evaluation: interrupts the jobs of the current/last thread pool (if any) via
     * {@code shutdownNow()} and then delegates to the superclass.
     */
    @Override
    public void stop() {
        ExecutorService executor = m_Executor;
        if (executor != null) {
            debug("pre: shutdownNow");
            executor.shutdownNow();
            debug("post: shutdownNow");
        }
        super.stop();
    }
}
