package meka.classifiers.multitarget;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import meka.classifiers.multilabel.Evaluation;
import meka.classifiers.multilabel.MultiTargetCapable;
import meka.classifiers.multilabel.ProblemTransformationMethod;
import meka.core.A;
import meka.core.MLEvalUtils;
import meka.core.MLUtils;
import meka.core.MatrixUtils;
import meka.core.OptionUtils;
import meka.core.Result;
import meka.core.StatUtils;
import meka.core.SuperLabelUtils;
import meka.filters.multilabel.SuperNodeFilter;
import org.apache.xerces.impl.xs.SchemaSymbols;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;

/* loaded from: input_file:lib/meka-1.9.7.jar:meka/classifiers/multitarget/SCC.class */
/**
 * Super Class Classifier (SCC).
 *
 * <p>Groups the target variables into "super classes" (a partition of the label indices), transforms
 * the data accordingly with a {@link SuperNodeFilter}, and trains a multi-target base classifier on
 * the transformed data. The partition is searched for in {@link #buildClassifier(Instances)}: first by
 * a stochastic search scored against a dependency matrix, then (optionally) by internal validation.
 *
 * <p>Field names are kept as-is because the class is serializable with a fixed
 * {@code serialVersionUID}.
 */
public class SCC extends NSR implements Randomizable, MultiTargetClassifier, TechnicalInformationHandler {
    private static final long serialVersionUID = 6517394813440480854L;

    /** Percentage of the shuffled data used as the internal training split; the rest is the internal test split. */
    private static final int i_SPLIT = 67;

    /** Name of the measurement read from internal evaluation results. */
    private static final String i_ErrFn = "Exact match";

    /** Filter that maps instances into the super-class space; replaced on every {@link #trainClassifier}. */
    private SuperNodeFilter f = new SuperNodeFilter();

    /** Number of internal-validation iterations (option {@code -V}); 0 disables that phase. */
    private int m_Iv = 0;

    /** Number of iterations of the matrix-guided partition search (option {@code -I}). */
    private int m_I = 1000;

    /** Random source, created from the seed {@code m_S} at the start of {@link #buildClassifier}. */
    private Random rand = null;

    /** Creates the classifier with {@link CC} as the base classifier. */
    public SCC() {
        this.m_Classifier = new CC();
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return {@code "meka.classifiers.multitarget.CC"}
     */
    @Override
    protected String defaultClassifierString() {
        return "meka.classifiers.multitarget.CC";
    }

    /**
     * Returns a description of this classifier followed by its literature reference.
     *
     * @return the description text
     */
    @Override
    public String globalInfo() {
        return "Super Class Classifier (SCC).\nThe output space is manipulated into super classes (based on label dependence; and pruning and nearest-subset-replacement like NSR), upon which a multi-target base classifier is applied.\nFor example, a super class based on two labels might take values in {[0,3],[0,0],[1,2]}.\nFor more information see:\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the literature reference for this method.
     *
     * <p>The entry is typed as an article because it carries a {@code JOURNAL} field; typed as
     * in-proceedings the journal name would have no matching field to be rendered from.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Jesse Read, Concha Bielza, Pedro Larranaga");
        info.setValue(TechnicalInformation.Field.TITLE, "Multi-Dimensional Classification with Super-Classes");
        info.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Knowledge and Data Engineering");
        info.setValue(TechnicalInformation.Field.YEAR, "2013");
        return info;
    }

    /**
     * Rates a partition against a pairwise matrix with no offset.
     *
     * @param partition the partition of label indices (each inner array is sorted in place)
     * @param matrix    square pairwise matrix; only the upper triangle is read
     * @return the rating, see {@link #rating(int[][], double[][], double)}
     */
    private double rating(int[][] partition, double[][] matrix) {
        return rating(partition, matrix, 0.0d);
    }

    /**
     * Rates a partition: the sum of {@code matrix[j][k] - offset} over all pairs {@code j < k} that share
     * a set, minus the same sum over all pairs that do not.
     *
     * <p>Side effect: every inner array of {@code partition} is sorted in place, so that each
     * within-set pair is recorded in the upper triangle.
     *
     * @param partition the partition of label indices
     * @param matrix    square pairwise matrix; only the upper triangle is read
     * @param offset    value subtracted from every matrix entry before summing
     * @return within-set total minus across-set total
     */
    private double rating(int[][] partition, double[][] matrix, double offset) {
        int size = matrix.length;
        boolean[][] together = new boolean[size][size];
        for (int s = 0; s < partition.length; s++) {
            int[] set = partition[s];
            Arrays.sort(set);
            for (int a = 0; a < set.length; a++) {
                for (int b = a + 1; b < set.length; b++) {
                    together[set[a]][set[b]] = true;
                }
            }
        }
        double within = 0.0d;
        double across = 0.0d;
        for (int j = 0; j < size; j++) {
            for (int k = j + 1; k < size; k++) {
                if (together[j][k]) {
                    within += matrix[j][k] - offset;
                } else {
                    across += matrix[j][k] - offset;
                }
            }
        }
        return within - across;
    }

    /**
     * Applies one random move to a partition: a randomly chosen element is taken out of a randomly
     * chosen set and either placed in a new singleton set (when the randomly chosen target set is the
     * source set itself) or appended to the target set. A source set left empty is removed.
     *
     * <p>The outer array may be modified and/or replaced; callers should pass a copy and use the
     * returned array.
     *
     * @param partition the partition to mutate
     * @param random    the random source (three draws are taken, in a fixed order)
     * @return the mutated partition
     */
    private int[][] mutateCombinations(int[][] partition, Random random) {
        int from = random.nextInt(partition.length);
        int element = random.nextInt(partition[from].length);
        int to = random.nextInt(partition.length);
        int moved = partition[from][element];
        if (to == from) {
            // Same set drawn twice: split the element off into a new singleton set at the end.
            partition = Arrays.copyOf(partition, partition.length + 1);
            partition[partition.length - 1] = new int[] {moved};
        } else {
            partition[to] = A.append(partition[to], moved);
        }
        partition[from] = A.delete(partition[from], element);
        if (partition[from].length <= 0) {
            // Fill the emptied slot with the last set, then drop the last slot.
            partition[from] = partition[partition.length - 1];
            partition = Arrays.copyOf(partition, partition.length - 1);
        }
        return partition;
    }

    /**
     * Transforms {@code instances} into the super-class space defined by {@code iArr} and builds the
     * base classifier on the result. Replaces the current filter and instances template.
     *
     * <p>Negative {@code m_P} / {@code m_N} settings are replaced by a random value in
     * {@code [0, |setting|)}, which requires {@link #buildClassifier} to have initialised the random
     * source.
     *
     * @param classifier not used; the base classifier of this object is always the one trained
     * @param instances  the training data
     * @param iArr       the partition of label indices to use
     * @throws Exception if filtering or training fails
     */
    public void trainClassifier(Classifier classifier, Instances instances, int[][] iArr) throws Exception {
        this.f = new SuperNodeFilter();
        this.f.setIndices(iArr);
        this.f.setP(this.m_P >= 0 ? this.m_P : this.rand.nextInt(Math.abs(this.m_P)));
        this.f.setN(this.m_N >= 0 ? this.m_N : this.rand.nextInt(Math.abs(this.m_N)));
        Instances transformed = this.f.process(instances);
        if (getDebug()) {
            System.out.println("PS(" + this.f.getP() + "," + this.m_N + ") reduced: " + instances.numInstances() + " -> " + transformed.numInstances() + " / " + MLUtils.numberOfUniqueCombinations(instances) + " -> " + MLUtils.numberOfUniqueCombinations(transformed));
        }
        this.m_InstancesTemplate = transformed;
        this.m_Classifier.buildClassifier(transformed);
    }

    /**
     * Trains on {@code instances} with the given partition (via {@link #trainClassifier}), evaluates
     * {@code classifier} on {@code instances2} and fills in the usual result metadata.
     *
     * @param classifier the classifier to evaluate; must be a {@link ProblemTransformationMethod}
     * @param instances  the training data
     * @param instances2 the test data
     * @param iArr       the partition of label indices to train with
     * @return the evaluation result, with {@code output} populated
     * @throws Exception if training or evaluation fails
     */
    public Result testClassifier(Classifier classifier, Instances instances, Instances instances2, int[][] iArr) throws Exception {
        trainClassifier(this.m_Classifier, instances, iArr);
        Result result = Evaluation.testClassifier((ProblemTransformationMethod) classifier, instances2);
        if ((classifier instanceof MultiTargetClassifier) || Evaluation.isMT(instances2)) {
            result.setInfo("Type", "MT");
        } else if (classifier instanceof ProblemTransformationMethod) {
            result.setInfo("Threshold", MLEvalUtils.getThreshold(result.predictions, instances, "PCut1"));
            result.setInfo("Type", "ML");
        }
        result.setValue("N_train", instances.numInstances());
        result.setValue("N_test", instances2.numInstances());
        result.setValue("LCard_train", MLUtils.labelCardinality(instances));
        result.setValue("LCard_test", MLUtils.labelCardinality(instances2));
        result.setInfo("Classifier_name", classifier.getClass().getName());
        result.setInfo("Classifier_info", classifier.toString());
        result.setInfo("Dataset_name", MLUtils.getDatasetName(instances2));
        // Plain literal "1": same value as the XML-schema constant this was previously routed through.
        result.output = Result.getStats(result, "1");
        return result;
    }

    /**
     * Searches for a partition of the labels into super classes and trains the base classifier on the
     * full data using the partition found.
     *
     * <p>Steps: (1) shuffle and split the data, evaluate a {@link CR} model on the split;
     * (2) compute a dependency matrix from that evaluation; (3) run {@code m_I} random mutations of
     * the partition, always accepting a better-rated one and sometimes accepting a worse one;
     * (4) optionally run {@code m_Iv} mutations judged by internal evaluation; (5) train on all data.
     *
     * @param instances the training data
     * @throws Exception if the base classifier is not multi-target capable, or training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        final boolean debug = getDebug();
        int numLabels = instances.classIndex();
        this.rand = new Random(this.m_S);
        if (!(this.m_Classifier instanceof MultiTargetClassifier) && !(this.m_Classifier instanceof MultiTargetCapable)) {
            throw new Exception("[Error] The base classifier must be multi-target capable, i.e., from meka.classifiers.multitarget.");
        }

        // Internal train/test split on a shuffled copy; the caller's data is left untouched.
        Instances shuffled = new Instances(instances);
        shuffled.randomize(this.rand);
        int trainSize = (shuffled.numInstances() * i_SPLIT) / 100;
        Instances trainSplit = new Instances(shuffled, 0, trainSize);
        Instances testSplit = new Instances(shuffled, trainSize, shuffled.numInstances() - trainSize);

        if (debug) {
            System.out.print("1. BUILD & Evaluate BR: ");
        }
        CR independent = new CR();
        independent.setClassifier(((ProblemTransformationMethod) this.m_Classifier).getClassifier());
        Result baseline = Evaluation.evaluateModel(independent, trainSplit, testSplit, "PCut1", "5");
        double baselineScore = ((Double) baseline.getMeasurement(i_ErrFn)).doubleValue();
        if (debug) {
            System.out.println(" " + baselineScore);
        }

        int[][] partition = SuperLabelUtils.generatePartition(A.make_sequence(numLabels), this.rand);

        if (debug) {
            System.out.println("2. GET ERR-CHI-SQUARED MATRIX: ");
        }
        double[][] dependency = StatUtils.condDepMatrix(testSplit, baseline);
        if (debug) {
            System.out.println(MatrixUtils.toString(dependency));
            System.out.println("3. COMBINE NODES TO FIND THE BEST COMBINATION ACCORDING TO CHI");
        }

        double best = rating(partition, dependency);
        if (debug) {
            System.out.println("@0 : " + SuperLabelUtils.toString(partition) + "\t(" + best + ")");
        }
        for (int i = 0; i < this.m_I; i++) {
            int[][] candidate = mutateCombinations(MatrixUtils.deep_copy(partition), this.rand);
            double score = rating(candidate, dependency);
            if (score > best) {
                partition = candidate;
                best = score;
                if (debug) {
                    System.out.println("@" + i + " : " + SuperLabelUtils.toString(partition) + "\t(" + best + ")");
                }
            } else if (2.0d * (1.0d - sigma((Math.abs(score - best) * i) / 1000.0d)) > this.rand.nextDouble()) {
                // Not better: accept anyway with a probability that shrinks as the gap and the
                // iteration count grow. The random draw happens only on this path.
                if (debug) {
                    System.out.println("@" + i + " : " + SuperLabelUtils.toString(candidate) + "\t(" + score + ")*");
                }
                partition = candidate;
                best = score;
            }
        }

        if (this.m_Iv > 0) {
            if (debug) {
                System.out.println("4. REFINING THE INITIAL SET WITH SOME OLD-FASHIONED INTERNAL EVAL");
            }
            double bestEval = ((Double) testClassifier((ProblemTransformationMethod) this.m_Classifier, trainSplit, testSplit, partition).getMeasurement(i_ErrFn)).doubleValue();
            if (debug) {
                System.out.println("@0 : " + SuperLabelUtils.toString(partition) + "\t(" + bestEval + ")");
            }
            for (int v = 0; v < this.m_Iv; v++) {
                int[][] candidate = mutateCombinations(MatrixUtils.deep_copy(partition), this.rand);
                // NOTE(review): testClassifier below retrains with the candidate, so this training
                // run looks redundant. It is kept because it may consume random numbers (negative
                // m_P / m_N), and dropping it would change the random sequence - confirm before removing.
                trainClassifier(this.m_Classifier, trainSplit, partition);
                double eval = ((Double) testClassifier((ProblemTransformationMethod) this.m_Classifier, trainSplit, testSplit, candidate).getMeasurement(i_ErrFn)).doubleValue();
                if (eval > bestEval) {
                    bestEval = eval;
                    partition = candidate;
                    if (debug) {
                        System.out.println("@" + (v + 1) + "' : " + SuperLabelUtils.toString(partition) + "\t(" + bestEval + ")");
                    }
                }
            }
        }

        if (debug) {
            System.out.println("4. TRAIN " + SuperLabelUtils.toString(partition));
        }
        trainClassifier(this.m_Classifier, instances, partition);
    }

    /**
     * Predicts for one instance by classifying it in the super-class space and decoding each predicted
     * super-class value back into values for the original labels.
     *
     * <p>The returned array has twice as many entries as the instance has labels: entry {@code j}
     * holds the predicted value index for label {@code j}; entry {@code j + L} holds the base
     * classifier's second-half output for the super class containing label {@code j} (presumably a
     * confidence - verify against the base classifier).
     *
     * <p>Best effort: if anything fails, the problem is reported on standard error and the array is
     * returned as filled so far (zeros where nothing was written).
     *
     * @param instance the instance to classify
     * @return the prediction array described above
     * @throws Exception declared by the overridden method
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        int numLabels = instance.classIndex();
        double[] y = new double[numLabels * 2];
        int numSuper = this.m_InstancesTemplate.classIndex();
        double[] ySuper = null;
        try {
            ySuper = ((ProblemTransformationMethod) this.m_Classifier).distributionForInstance(MLUtils.setTemplate(instance, this.f.getTemplate(), this.m_InstancesTemplate));
            for (int s = 0; s < numSuper; s++) {
                // The super-class attribute name encodes the member label indices; its predicted
                // nominal value encodes one value per member label.
                int[] members = SuperNodeFilter.decodeClasses(this.m_InstancesTemplate.attribute(s).name());
                String[] values = SuperNodeFilter.decodeValue(this.m_InstancesTemplate.attribute(s).value((int) Math.round(ySuper[s])));
                for (int m = 0; m < members.length; m++) {
                    int label = members[m];
                    y[label] = instance.dataset().attribute(label).indexOfValue(values[m]);
                    y[label + numLabels] = ySuper[s + numSuper];
                }
            }
            return y;
        } catch (Exception e) {
            System.err.println("EXCEPTION !!! setting to " + Arrays.toString(ySuper));
            // Report the cause as well, so the failure can be diagnosed.
            System.err.println("Caused by: " + e);
            return y;
        }
    }

    /**
     * Sets the number of simulated annealing iterations.
     *
     * @param i the number of iterations
     */
    public void setI(int i) {
        this.m_I = i;
    }

    /**
     * Returns the number of simulated annealing iterations.
     *
     * @return the number of iterations
     */
    public int getI() {
        return this.m_I;
    }

    /**
     * Returns the tip text for the {@code I} property.
     *
     * @return the tip text
     */
    public String iTipText() {
        return "the number of simulated annealing iterations";
    }

    /**
     * Sets the number of internal-validation iterations.
     *
     * @param i the number of iterations; 0 skips internal validation
     */
    public void setIv(int i) {
        this.m_Iv = i;
    }

    /**
     * Returns the number of internal-validation iterations.
     *
     * @return the number of iterations
     */
    public int getIv() {
        return this.m_Iv;
    }

    /**
     * Returns the tip text for the {@code Iv} property.
     *
     * @return the tip text
     */
    public String ivTipText() {
        return "the number of internal-validation iterations";
    }

    /**
     * Command-line entry point; hands a new instance and the arguments to the shared evaluation routine.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        ProblemTransformationMethod.evaluation(new SCC(), strArr);
    }

    /**
     * Logistic sigmoid, {@code 1 / (1 + e^-d)}.
     *
     * @param d the input value
     * @return a value between 0 and 1; exactly 0.5 for {@code d == 0}
     */
    public static final double sigma(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Lists the options of this classifier ({@code -I}, {@code -V}) followed by the inherited ones.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tSets the number of simulated annealing iterations\n\tdefault: 1000", "I", 1, "-I <value>"));
        options.addElement(new Option("\tSets the number of internal-validation iterations\n\tdefault: 0", "V", 1, "-V <value>"));
        OptionUtils.add(options, super.listOptions());
        return OptionUtils.toEnumeration(options);
    }

    /**
     * Parses {@code -I} (default 1000) and {@code -V} (default 0), then passes the options on to the superclass.
     *
     * @param strArr the option strings
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        setI(OptionUtils.parse(strArr, 'I', 1000));
        setIv(OptionUtils.parse(strArr, 'V', 0));
        super.setOptions(strArr);
    }

    /**
     * Returns the current settings as option strings: {@code -I}, {@code -V}, then the inherited options.
     *
     * @return the option strings
     */
    @Override
    public String[] getOptions() {
        List<String> options = new ArrayList<String>();
        OptionUtils.add(options, 'I', getI());
        OptionUtils.add(options, 'V', getIv());
        OptionUtils.add(options, super.getOptions());
        return OptionUtils.toArray(options);
    }
}
