package moa.learners.featureanalysis;

import com.github.javacliparser.FileOption;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.meta.AdaptiveRandomForest;
import moa.core.Measurement;
import moa.core.Utils;
import moa.options.ClassOption;
import org.apache.commons.cli.HelpFormatter;
import org.jfree.chart.axis.ValueAxis;

/* loaded from: input_file:lib/moa.jar:moa/learners/featureanalysis/ClassifierWithFeatureImportance.class */
/**
 * Wraps a {@link FeatureImportanceClassifier}, delegating training and prediction to it, and
 * every {@code windowSize} training instances summarises its current feature importances
 * (median, mean, max, min, sum). When enabled, each summary is appended as a CSV row to the
 * configured debug file, together with the top-ranked features and the per-attribute scores.
 */
public class ClassifierWithFeatureImportance extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;

    /**
     * Destination of the CSV debug output, or {@code null} when nothing is written.
     * {@code transient} because {@link PrintStream} is not serializable: without it, serializing
     * this learner after the stream has been opened fails.
     */
    protected transient PrintStream debugStream;

    /** The wrapped learner that is trained and queried for feature importances. */
    protected FeatureImportanceClassifier featureImportanceClassifierLearner;

    public ClassOption featureImportanceLearnerOption = new ClassOption("featureImportanceLearner", 'l', "Learner used to build the model from which the feature importances are extracted", FeatureImportanceClassifier.class, "moa.learners.featureanalysis.FeatureImportanceHoeffdingTree");
    public FlagOption doNotNormalizeFeatureScoreOption = new FlagOption("doNotNormalizeFeatureScore", 'n', "If set the feature importances will not be normalized");
    public IntOption windowSizeOption = new IntOption("windowSize", 'w', "The amount of instances seen before inspecting the feature scores again.", 500, 1, Integer.MAX_VALUE);
    public IntOption maxFeaturesDebugOption = new IntOption("maxFeaturesDebug", 'o', "The maximum number of features to show in the debug.", 100, 1, Integer.MAX_VALUE);
    public FileOption debugFileOption = new FileOption("debugFile", 'c', "File to append the feature scores.", "debug.csv", "csv", true);
    public FlagOption doNotOutputResultsToFileOption = new FlagOption("doNotOutputResultsToFile", 'd', "To evaluate the resources usage without writing the feature importance to a file.");

    /** Number of training instances passed to {@link #trainOnInstanceImpl} since the last reset. */
    protected long instancesSeen = 0;

    // Summary statistics of the most recently inspected feature importances. They keep these
    // initial values until the first window of windowSize instances has been processed.
    protected double mean = -1.0d;
    protected double median = -1.0d;
    protected double max = -1.0d;
    protected double min = Double.POSITIVE_INFINITY;
    protected double sum = 0.0d;

    @Override
    public String getPurposeString() {
        return "Only produces feature scores for tree-based algorithms.";
    }

    /**
     * Opens the file named by {@code debugFile} in append mode and stores it in
     * {@link #debugStream}. Any stream opened by a previous call is closed first. The stream is
     * left {@code null} when {@code doNotOutputResultsToFile} is set or no file is configured.
     *
     * @throws RuntimeException if the file cannot be opened
     */
    protected void createDebugOutputFile() {
        // resetLearningImpl() calls this on every reset; close the old stream so it is not leaked.
        closeDebugStream();
        if (this.doNotOutputResultsToFileOption.isSet()) {
            return;
        }
        File file = this.debugFileOption.getFile();
        if (file == null) {
            return;
        }
        try {
            // Append mode also creates the file when it does not exist yet.
            this.debugStream = new PrintStream(new FileOutputStream(file, true), true);
        } catch (FileNotFoundException | SecurityException e) {
            throw new RuntimeException("Unable to open immediate result file: " + file, e);
        }
    }

    /** Closes {@link #debugStream} if it is open and clears the reference. */
    private void closeDebugStream() {
        if (this.debugStream != null) {
            this.debugStream.close();
            this.debugStream = null;
        }
    }

    /**
     * Re-creates the wrapped learner from its option, resets it, clears the instance counter and
     * summary statistics, and (re)opens the debug output file.
     */
    @Override
    public void resetLearningImpl() {
        this.instancesSeen = 0L;
        this.mean = -1.0d;
        this.median = -1.0d;
        this.max = -1.0d;
        this.min = Double.POSITIVE_INFINITY;
        this.sum = 0.0d;
        this.featureImportanceClassifierLearner = (FeatureImportanceClassifier) getPreparedClassOption(this.featureImportanceLearnerOption);
        this.featureImportanceClassifierLearner.resetLearning();
        createDebugOutputFile();
    }

    /**
     * Trains the wrapped learner on {@code instance}. On the first instance after a reset the CSV
     * header is written (if debug output is enabled). Every {@code windowSize} instances, the
     * summary statistics are recomputed and a CSV row is written.
     *
     * @param instance the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.instancesSeen == 0 && this.debugStream != null) {
            writeDebugHeader(instance);
        }
        this.instancesSeen++;
        this.featureImportanceClassifierLearner.trainOnInstance(instance);
        if (this.instancesSeen % this.windowSizeOption.getValue() != 0) {
            return;
        }
        boolean normalize = !this.doNotNormalizeFeatureScoreOption.isSet();
        double[] featureImportances = this.featureImportanceClassifierLearner.getFeatureImportances(normalize);
        // Asks for a ranking of all input attributes (every attribute except one, presumably
        // the class attribute).
        int[] topKFeatures = this.featureImportanceClassifierLearner.getTopKFeatures(instance.numAttributes() - 1, normalize);
        if (featureImportances == null || topKFeatures == null
                || featureImportances.length == 0 || topKFeatures.length == 0) {
            // Nothing to summarise (e.g. no input attributes). The median below would otherwise
            // index into an empty array.
            return;
        }
        updateStatistics(featureImportances, topKFeatures);
        if (this.debugStream != null) {
            writeDebugRow(instance, featureImportances, topKFeatures);
        }
    }

    /** Writes the learner description line and the CSV column header to {@link #debugStream}. */
    private void writeDebugHeader(Instance instance) {
        int columns = Math.min(instance.numAttributes() - 1, this.maxFeaturesDebugOption.getValue());
        this.debugStream.println(describe());
        this.debugStream.print("instancesSeen,median,mean,max,min,sum");
        for (int i = 0; i < columns; i++) {
            this.debugStream.print(",top" + i);
        }
        for (int i = 0; i < columns; i++) {
            this.debugStream.print(",score(top" + i + ")");
        }
        for (int i = 0; i < columns; i++) {
            this.debugStream.print(",score(att" + i + "-" + instance.attribute(i).name() + ")");
        }
        this.debugStream.println();
    }

    /**
     * Recomputes median, mean, max, min and sum of {@code featureImportances}. Both arrays must
     * be non-empty. The median is read through the order given by {@code topKFeatures}
     * (assumed to be indices sorted by score; confirm against getTopKFeatures).
     */
    private void updateStatistics(double[] featureImportances, int[] topKFeatures) {
        int half = topKFeatures.length / 2;
        if (topKFeatures.length % 2 == 0) {
            this.median = (featureImportances[topKFeatures[half]] + featureImportances[topKFeatures[half - 1]]) / 2.0d;
        } else {
            this.median = featureImportances[topKFeatures[half]];
        }
        this.mean = Utils.mean(featureImportances);
        this.max = featureImportances[Utils.maxIndex(featureImportances)];
        this.min = featureImportances[Utils.minIndex(featureImportances)];
        this.sum = Utils.sum(featureImportances);
    }

    /**
     * Writes one CSV row to {@link #debugStream}. The row holds the summary statistics, then the
     * top-ranked features, then their scores, then the raw per-attribute scores. Each group is
     * limited by {@code maxFeaturesDebug}.
     */
    private void writeDebugRow(Instance instance, double[] featureImportances, int[] topKFeatures) {
        int maxColumns = this.maxFeaturesDebugOption.getValue();
        int topColumns = Math.min(topKFeatures.length, maxColumns);
        // score(topN) indexes topKFeatures and featureImportances, so stay within both.
        int topScoreColumns = Math.min(topColumns, featureImportances.length);
        int attributeColumns = Math.min(featureImportances.length, maxColumns);

        this.debugStream.print(this.instancesSeen + "," + this.median + "," + this.mean + "," + this.max + "," + this.min + "," + this.sum);
        for (int i = 0; i < topColumns; i++) {
            this.debugStream.print("," + topKFeatures[i] + "-" + instance.attribute(topKFeatures[i]).name());
        }
        for (int i = 0; i < topScoreColumns; i++) {
            this.debugStream.print("," + featureImportances[topKFeatures[i]]);
        }
        for (int i = 0; i < attributeColumns; i++) {
            this.debugStream.print("," + featureImportances[i]);
        }
        this.debugStream.println();
    }

    /**
     * Returns the wrapped learner's current feature importances. They are normalized unless
     * {@code doNotNormalizeFeatureScore} is set.
     *
     * @return the array returned by the wrapped learner's {@code getFeatureImportances}
     */
    public double[] getCurrentFeatureImportances() {
        return this.featureImportanceClassifierLearner.getFeatureImportances(!this.doNotNormalizeFeatureScoreOption.isSet());
    }

    /** Delegates prediction to the wrapped learner. */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        return this.featureImportanceClassifierLearner.getVotesForInstance(instance);
    }

    /** @return always {@code true} */
    @Override
    public boolean isRandomizable() {
        return true;
    }

    /** Intentionally appends nothing; this wrapper produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    /** @return {@code null}; this wrapper reports no model measurements of its own */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /**
     * Builds a one-line, underscore-separated summary of the wrapped learner's configuration.
     * It is written as the first line of the debug output.
     *
     * @return the description, starting with the learner's class name without its package
     */
    protected String describe() {
        FeatureImportanceClassifier learner = this.featureImportanceClassifierLearner;
        StringBuilder sb = new StringBuilder();
        String name = learner.getClass().getName();
        sb.append(name.substring(name.lastIndexOf('.') + 1));
        sb.append("_norm");
        sb.append(this.doNotNormalizeFeatureScoreOption.isSet() ? "NO" : "YES");
        if (learner instanceof FeatureImportanceHoeffdingTree) {
            FeatureImportanceHoeffdingTree tree = (FeatureImportanceHoeffdingTree) learner;
            sb.append("_gp");
            sb.append(tree.treeLearner.gracePeriodOption.getValue());
            sb.append("_sc");
            sb.append(tree.treeLearner.splitConfidenceOption.getValue());
            sb.append("_");
            sb.append(tree.treeLearner.splitCriterionOption.getValueAsCLIString());
            sb.append("_");
            sb.append(tree.featureImportanceOption.getChosenLabel());
        }
        if (learner.getSubClassifiers() != null) {
            sb.append("_s");
            sb.append(learner.getSubClassifiers().length);
            if (learner instanceof FeatureImportanceHoeffdingTreeEnsemble
                    && ((FeatureImportanceHoeffdingTreeEnsemble) learner).ensemble instanceof AdaptiveRandomForest) {
                AdaptiveRandomForest forest = (AdaptiveRandomForest) ((FeatureImportanceHoeffdingTreeEnsemble) learner).ensemble;
                sb.append("_m");
                sb.append(forest.mFeaturesModeOption.getChosenLabel());
                sb.append("_");
                sb.append(forest.mFeaturesPerTreeSizeOption.getValue());
                sb.append("_lambda");
                sb.append(forest.lambdaOption.getValue());
                sb.append("_");
                sb.append(forest.treeLearnerOption.getValueAsCLIString());
            }
        }
        return sb.toString();
    }
}
