package meka.classifiers.multitarget;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import meka.classifiers.multilabel.PS;
import meka.classifiers.multilabel.ProblemTransformationMethod;
import meka.core.A;
import meka.core.MLUtils;
import meka.core.PSUtils;
import meka.core.SuperLabelUtils;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/* loaded from: input_file:lib/meka-1.9.7.jar:meka/classifiers/multitarget/NSR.class */
/**
 * The Nearest Set Replacement (NSR) method: a multi-target variant of {@link PS}.
 * <p>
 * The target attributes are collapsed into a single nominal {@code CLASS} attribute whose values
 * are the label combinations that survive pruning with {@code m_P}. A training instance whose
 * combination was pruned away is replaced by weighted copies labelled with the combinations
 * returned by {@code SuperLabelUtils.getTopNSubsets} (at most {@code m_N} of them).
 */
public class NSR extends PS implements MultiTargetClassifier {
    private static final long serialVersionUID = 8373228150066785001L;

    /** Creates an NSR classifier that uses {@link J48} as its base classifier. */
    public NSR() {
        this.m_Classifier = new J48();
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return {@code "weka.classifiers.trees.J48"}
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.J48";
    }

    /**
     * Returns a short description of this method.
     *
     * @return the description shown in user interfaces
     */
    @Override
    public String globalInfo() {
        return "The Nearest Set Replacement (NSR) method.\nA multi-target version of PS: The nearest sets are used to replace outliers, rather than subsets (as in PS).";
    }

    /**
     * Returns the inherited capabilities, additionally requiring at least one training instance.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.setMinimumNumberInstances(1);
        return capabilities;
    }

    /**
     * Builds the base classifier on the transformed, single-target version of {@code instances}.
     * <p>
     * If transforming the data or building the base classifier throws, {@code m_P} (the value
     * passed to {@code MLUtils.pruneCountHashMap}) is decremented and the build is retried. Retries
     * continue until the build succeeds or {@code m_P} is no longer positive. {@code m_P} keeps its
     * lowered value afterwards.
     *
     * @param instances the multi-target training data
     * @throws Exception if the capabilities test fails, or if the build still fails once
     *                   {@code m_P <= 0}; the last failure is attached as the cause
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        testCapabilities(instances);
        // Iterate instead of recursing so repeated retries do not grow the call stack.
        while (true) {
            try {
                this.m_Classifier.buildClassifier(convertInstances(instances, instances.classIndex()));
                return;
            } catch (Exception e) {
                if (this.m_P <= 0) {
                    // Preserve the underlying failure so callers can see why the build failed.
                    throw new Exception("Failed to construct a classifier.", e);
                }
                this.m_P--;
                System.err.println("Not enough distinct class values, trying again with P = " + this.m_P + " ...");
            }
        }
    }

    /**
     * Predicts all targets of {@code instance}.
     * <p>
     * {@code instance.classIndex()} is used as the number of targets L. This presumably relies on the
     * MEKA convention that the L target attributes come first; verify against callers.
     *
     * @param instance the multi-target instance to classify
     * @return an array of length {@code 2 * L}. Entries {@code [0, L)} hold the target values decoded
     *         from the highest-scoring combination. Entry {@code L + j} holds, for target {@code j},
     *         the largest total score that any single value of that target accumulates over all
     *         combinations, or 0.0 if no combination covers target {@code j}.
     * @throws Exception if the base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        int numTargets = instance.classIndex();
        Instance converted = PSUtils.convertInstance(instance, numTargets, this.m_InstancesTemplate);
        double[] scores = this.m_Classifier.distributionForInstance(converted);
        Attribute classAttribute = this.m_InstancesTemplate.classAttribute();

        // Prediction: decode the best combination; the array is padded with zeros to 2L.
        String bestCombination = classAttribute.value(Utils.maxIndex(scores));
        double[] result = Arrays.copyOf(A.toDoubleArray(MLUtils.decodeValue(bestCombination)), numTargets * 2);

        // votes[j] maps each value of target j to the summed score of all combinations containing it.
        @SuppressWarnings({"unchecked", "rawtypes"})
        HashMap<Double, Double>[] votes = new HashMap[numTargets];
        for (int j = 0; j < numTargets; j++) {
            votes[j] = new HashMap<Double, Double>();
        }
        for (int k = 0; k < scores.length; k++) {
            double[] combination = A.toDoubleArray(MLUtils.decodeValue(classAttribute.value(k)));
            for (int j = 0; j < combination.length; j++) {
                Double previous = votes[j].get(combination[j]);
                votes[j].put(combination[j], previous == null ? scores[k] : previous + scores[k]);
            }
        }

        // Confidence: the largest accumulated score for any value of each target.
        for (int j = 0; j < numTargets; j++) {
            result[numTargets + j] = votes[j].isEmpty() ? 0.0d : Collections.max(votes[j].values());
        }
        return result;
    }

    /**
     * Collapses a distribution over combination class values into a per-position indicator vector.
     * <p>
     * Each class value with a positive score is parsed with {@code MLUtils.fromBitString}. Every
     * position with a positive parsed entry is set to 1.0 in the result.
     * NOTE(review): a parsed value longer than {@code i} would raise an
     * ArrayIndexOutOfBoundsException. This method is not called within this class.
     *
     * @param dArr scores indexed like the class values of {@code m_InstancesTemplate}
     * @param i    length of the returned vector
     * @return a vector of length {@code i} containing only 0.0 and 1.0
     */
    public double[] convertDistribution(double[] dArr, int i) {
        double[] indicators = new double[i];
        for (int k = 0; k < dArr.length; k++) {
            if (dArr[k] <= 0.0d) {
                continue;
            }
            double[] parsed = MLUtils.fromBitString(this.m_InstancesTemplate.classAttribute().value(k));
            for (int j = 0; j < parsed.length; j++) {
                if (parsed[j] > 0.0d) {
                    indicators[j] = 1.0d;
                }
            }
        }
        return indicators;
    }

    /**
     * Selects known combinations close to {@code str}, most frequent first.
     * <p>
     * A combination qualifies when {@code MLUtils.bitDifference} between it and {@code str} is at
     * most 1. Both are split on {@code '+'} before comparison.
     *
     * @param str     a label combination whose values are joined by {@code '+'}
     * @param hashMap occurrence counts of known combinations, keyed with the same encoding
     * @param i       maximum number of combinations to return; a non-positive value yields an
     *                empty array
     * @return up to {@code i} qualifying combinations, sorted by descending count
     */
    public static String[] getTopNSubsets(String str, final HashMap<String, Integer> hashMap, int i) {
        String[] target = str.split("\\+");
        ArrayList<String> candidates = new ArrayList<String>();
        for (String combination : hashMap.keySet()) {
            if (MLUtils.bitDifference(target, combination.split("\\+")) <= 1) {
                candidates.add(combination);
            }
        }
        Collections.sort(candidates, new Comparator<String>() {
            @Override
            public int compare(String a, String b) {
                // Descending by count. Integer.compare keeps the ordering antisymmetric, as
                // Collections.sort requires.
                return Integer.compare(hashMap.get(b), hashMap.get(a));
            }
        });
        String[] sorted = candidates.toArray(new String[candidates.size()]);
        return Arrays.copyOf(sorted, Math.max(0, Math.min(i, sorted.length)));
    }

    /**
     * Transforms multi-target data into single-target data with one nominal {@code CLASS}
     * attribute.
     * <p>
     * The attributes at {@code MLUtils.gen_indices(i)} (presumably the first {@code i}, i.e. the
     * targets) are removed. A nominal {@code CLASS} attribute is inserted at index 0 and made the
     * class. Its values are the combinations kept after {@code MLUtils.pruneCountHashMap} with
     * {@code m_P}.
     * <p>
     * An instance whose combination was kept is labelled with it. Otherwise, if {@code m_N > 0},
     * weighted copies labelled with the combinations from {@code SuperLabelUtils.getTopNSubsets}
     * are appended. Instances left without a class value are then removed.
     * <p>
     * Side effect: stores the empty header of the result in {@code m_InstancesTemplate}.
     *
     * @param instances the multi-target data
     * @param i         the number of target attributes
     * @return the transformed data
     * @throws Exception if the underlying MEKA/Weka utilities fail
     */
    public Instances convertInstances(Instances instances, int i) throws Exception {
        HashMap<String, Integer> classCombinationCounts = MLUtils.classCombinationCounts(instances);
        if (getDebug()) {
            System.out.println("Found " + classCombinationCounts.size() + " unique combinations");
        }
        MLUtils.pruneCountHashMap(classCombinationCounts, this.m_P);
        if (getDebug()) {
            System.out.println("Pruned to " + classCombinationCounts.size() + " with P=" + this.m_P);
        }

        Instances converted = MLUtils.deleteAttributesAt(new Instances(instances), MLUtils.gen_indices(i));
        converted.insertAttributeAt(new Attribute("CLASS", new ArrayList<String>(classCombinationCounts.keySet())), 0);
        converted.setClassIndex(0);

        // Iterate over the original rows only; replacement copies are appended after them.
        for (int row = 0; row < instances.numInstances(); row++) {
            String combination = MLUtils.encodeValue(MLUtils.toIntArray(instances.instance(row), i));
            if (classCombinationCounts.containsKey(combination)) {
                converted.instance(row).setClassValue(combination);
            } else if (this.m_N > 0) {
                // The original row keeps a missing class and is removed below; its copies share
                // a total weight of 1.
                String[] replacements = SuperLabelUtils.getTopNSubsets(combination, classCombinationCounts, this.m_N);
                for (String replacement : replacements) {
                    Instance copy = (Instance) converted.instance(row).copy();
                    copy.setClassValue(replacement);
                    copy.setWeight(1.0d / replacements.length);
                    converted.add(copy);
                }
            }
        }

        converted.deleteWithMissingClass();
        this.m_InstancesTemplate = new Instances(converted, 0);
        if (getDebug()) {
            System.out.println(converted);
        }
        return converted;
    }

    /**
     * Splits an encoded combination into its component values.
     *
     * @param str values joined by {@code '+'}
     * @return the values as strings
     */
    public static String[] decodeValue(String str) {
        return str.split("\\+");
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9117 $");
    }

    /**
     * Runs an evaluation of NSR from the command line.
     *
     * @param strArr command-line options
     */
    public static void main(String[] strArr) {
        ProblemTransformationMethod.evaluation(new NSR(), strArr);
    }
}
