package meka.classifiers.multilabel;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import meka.classifiers.multitarget.CR;
import meka.core.MatrixUtils;
import meka.core.OptionUtils;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.matrix.Matrix;
import weka.core.matrix.SingularValueDecomposition;
import weka.gui.beans.xml.XMLBeans;

/* loaded from: input_file:lib/meka-1.9.7.jar:meka/classifiers/multilabel/PLST.class */
/**
 * PLST - Principal Label Space Transformation.
 *
 * <p>Compresses the label space of a multi-label problem: the label matrix is
 * recoded to +1/-1, centred on the per-label mean, decomposed with an SVD, and
 * projected onto the first {@link #getSize()} columns of the right singular
 * vector matrix V. The base classifier is trained on the compressed (numeric)
 * targets; predictions are projected back, un-shifted and thresholded at zero.
 */
public class PLST extends LabelTransformationClassifier implements TechnicalInformationHandler {
    private static final long serialVersionUID = 3761303322465321039L;

    /**
     * Command-line flag of the size option. Kept as a local literal so the
     * option name does not hinge on an unrelated GUI constant.
     */
    private static final String SIZE_OPTION = "size";

    /** 1 x L row vector holding the per-label mean of the +1/-1 coded training labels. */
    protected Matrix m_Shift;

    /** Empty dataset carrying the header ("att0", "att1", ...) of the compressed label space. */
    protected Instances m_PatternInstances;

    /** First {@code getSize()} columns of V from the SVD of the centred label matrix. */
    protected Matrix m_v = null;

    /** Number of dimensions of the compressed label space. */
    protected int m_Size = getDefaultSize();

    /**
     * Returns a description of this classifier, including its technical reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "PLST - Principal Label Space Transformation. Uses SVD to generate a matrix that transforms the label space. This implementation is adapted from the MatLab implementation provided by the authors.\n\nhttps://github.com/hsuantien/mlc_lsdr\n\nFor more information see:\n " + getTechnicalInformation();
    }

    /**
     * Returns the default base classifier: a {@link CR} wrapping a {@link LinearRegression}.
     *
     * @return the default classifier
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    protected Classifier getDefaultClassifier() {
        CR cr = new CR();
        cr.setClassifier(new LinearRegression());
        return cr;
    }

    /**
     * Returns the default size of the compressed label space.
     *
     * @return the default size (3)
     */
    protected int getDefaultSize() {
        return 3;
    }

    /**
     * Returns the size of the compressed label space.
     *
     * @return the configured size
     */
    public int getSize() {
        return this.m_Size;
    }

    /**
     * Sets the size of the compressed label space. The value is checked against
     * the actual number of labels when {@link #transformLabels(Instances)} runs.
     *
     * @param i the new size
     */
    public void setSize(int i) {
        this.m_Size = i;
    }

    /**
     * Returns the tip text for the size option.
     *
     * @return the tip text
     */
    public String sizeTipText() {
        return "Size of the compressed matrix. Should be \nless than the number of labels and more than 1.";
    }

    /**
     * Returns the bibliographic reference of the PLST publication.
     *
     * @return the technical information
     */
    @Override // weka.core.TechnicalInformationHandler
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Farbound Tai and Hsuan-Tien Lin");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "Multilabel classification with principal label space transformation");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Neural Computation");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2012");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "2508-2542");
        technicalInformation.setValue(TechnicalInformation.Field.VOLUME, "24");
        technicalInformation.setValue(TechnicalInformation.Field.NUMBER, "9");
        return technicalInformation;
    }

    /**
     * Lists the size option followed by the options of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    @SuppressWarnings({"rawtypes", "unchecked"}) // raw types kept to match the inherited/OptionUtils signatures
    public Enumeration listOptions() {
        Vector options = new Vector();
        OptionUtils.addOption(options, sizeTipText(), getDefaultSize(), SIZE_OPTION);
        OptionUtils.add(options, super.listOptions());
        return OptionUtils.toEnumeration(options);
    }

    /**
     * Returns the current option settings: the size option followed by the
     * options of the superclass.
     *
     * @return the options as a string array
     */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public String[] getOptions() {
        List<String> options = new ArrayList<>();
        OptionUtils.add(options, SIZE_OPTION, getSize());
        OptionUtils.add(options, super.getOptions());
        return OptionUtils.toArray(options);
    }

    /**
     * Parses the size option (falling back to the default size) and hands the
     * remaining options to the superclass.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails
     */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.AbstractClassifier, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        setSize(OptionUtils.parse(strArr, SIZE_OPTION, getDefaultSize()));
        super.setOptions(strArr);
    }

    /**
     * Learns the label-space transformation from the training data and returns
     * the training data with its labels replaced by the compressed targets.
     *
     * <p>Side effects: sets {@link #m_Shift}, {@link #m_v} and
     * {@link #m_PatternInstances}, which are needed for prediction.
     *
     * @param instances the training data; {@code extractPart(instances, true)} is
     *                  used as the label part and {@code extractPart(instances, false)}
     *                  as the feature part
     * @return dataset whose first {@code getSize()} attributes are the compressed
     *         labels ("att0", ...), followed by the feature attributes; the class
     *         index is set to {@code getSize()}
     * @throws IllegalArgumentException if the configured size is smaller than 1 or
     *                                  larger than the number of labels
     * @throws Exception if extracting or converting the data fails
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    public Instances transformLabels(Instances instances) throws Exception {
        Instances features = extractPart(instances, false);
        Instances labels = extractPart(instances, true);
        int numLabels = labels.numAttributes();
        int numInstances = labels.numInstances();
        int size = getSize();
        if (size < 1 || size > numLabels) {
            // V has one column per label, so more than numLabels columns cannot be selected.
            throw new IllegalArgumentException("Size of the compressed label space must be between 1 and the number of labels ("
                    + numLabels + "), but was " + size);
        }

        // Recode the labels to +1/-1 (anything other than 1.0 becomes -1) and
        // compute the mean of every recoded label column.
        Matrix codedLabels = MatrixUtils.instancesToMatrix(labels);
        double[] means = new double[numLabels];
        for (int col = 0; col < numLabels; col++) {
            double[] column = labels.attributeToDoubleArray(col);
            double sum = 0.0d;
            for (int row = 0; row < column.length; row++) {
                if (column[row] == 1.0d) {
                    sum += 1.0d;
                } else {
                    sum -= 1.0d;
                    codedLabels.set(row, col, -1.0d);
                }
            }
            means[col] = sum / column.length;
        }
        // Remember the shift; it is added back when decoding predictions.
        this.m_Shift = new Matrix(new double[][]{means});

        // Every row of the shift matrix is the same mean vector; the rows share
        // one array because the matrix is only read by minus().
        double[][] shiftRows = new double[numInstances][];
        for (int row = 0; row < numInstances; row++) {
            shiftRows[row] = means;
        }
        Matrix centeredLabels = codedLabels.minus(new Matrix(shiftRows));

        // Keep only the first `size` columns of V.
        Matrix fullV = new SingularValueDecomposition(centeredLabels).getV();
        double[][] fullVArray = fullV.getArray();
        double[][] reducedV = new double[fullV.getRowDimension()][size];
        for (int row = 0; row < reducedV.length; row++) {
            System.arraycopy(fullVArray[row], 0, reducedV[row], 0, size);
        }
        this.m_v = new Matrix(reducedV);

        // Encode the SAME centred +1/-1 matrix that was decomposed. Decoding
        // computes z * V^T + shift and thresholds at 0, which is only the inverse
        // of this encoding when the targets are built from the centred matrix
        // (not from the raw 0/1 labels).
        Matrix compressed = centeredLabels.times(this.m_v);

        ArrayList<Attribute> attributes = new ArrayList<>(compressed.getColumnDimension());
        for (int col = 0; col < compressed.getColumnDimension(); col++) {
            attributes.add(new Attribute("att" + col));
        }
        this.m_PatternInstances = new Instances("compressedlabels", attributes, compressed.getRowDimension());

        Instances compressedLabels = MatrixUtils.matrixToInstances(compressed, this.m_PatternInstances);
        Instances merged = Instances.mergeInstances(compressedLabels, features);
        merged.setClassIndex(size);
        return merged;
    }

    /**
     * Decodes a prediction made in the compressed label space back into binary
     * labels.
     *
     * @param dArr output of the base classifier; only its second half is used as
     *             the compressed prediction (presumably the first half holds the
     *             per-target class indices of the multi-target classifier -
     *             confirm against the base classifier)
     * @return one value per original label: 0.0 where the reconstructed value is
     *         negative, 1.0 otherwise
     * @throws IllegalStateException if no transformation has been learned yet
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    public double[] transformPredictionsBack(double[] dArr) {
        if (this.m_v == null || this.m_Shift == null) {
            throw new IllegalStateException("No label transformation available; transformLabels must be called first.");
        }
        int half = dArr.length / 2;
        double[] compressed = new double[half];
        System.arraycopy(dArr, half, compressed, 0, half);

        Matrix reconstructed = new Matrix(new double[][]{compressed}).times(this.m_v.transpose()).plus(this.m_Shift);
        double[] reconstructedRow = reconstructed.getArray()[0];
        double[] labels = new double[reconstructed.getColumnDimension()];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = reconstructedRow[i] < 0.0d ? 0.0d : 1.0d;
        }
        return labels;
    }

    /**
     * Converts a single instance into the layout produced by
     * {@link #transformLabels(Instances)}: placeholder attributes for the
     * compressed labels followed by the instance's feature attributes.
     *
     * @param instance the instance to convert; it must belong to a dataset
     * @return the converted instance, with the class index of its dataset set to
     *         the number of compressed label attributes
     * @throws IllegalStateException if no transformation has been learned yet
     * @throws Exception if extracting the feature part fails
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    public Instance transformInstance(Instance instance) throws Exception {
        if (this.m_PatternInstances == null) {
            throw new IllegalStateException("No label transformation available; transformLabels must be called first.");
        }
        // Header-only copy: avoids copying (and then deleting) the whole dataset
        // for every single instance that is transformed.
        Instances single = new Instances(instance.dataset(), 0);
        single.add(instance);
        Instances features = extractPart(single, false);

        // One placeholder row for the (unknown) compressed labels.
        Instances placeholder = new Instances(this.m_PatternInstances, 1);
        placeholder.add(new DenseInstance(placeholder.numAttributes()));

        Instances merged = Instances.mergeInstances(placeholder, features);
        merged.setClassIndex(placeholder.numAttributes());
        return merged.instance(0);
    }

    /**
     * Returns a textual description of the model; this classifier provides none.
     *
     * @return an empty string
     */
    @Override // meka.classifiers.MultiXClassifier
    public String getModel() {
        return "";
    }

    /**
     * Returns the same text as {@link #getModel()}.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        return getModel();
    }

    /**
     * Runs an evaluation of this classifier with the given command-line options.
     *
     * @param strArr the command-line options
     * @throws Exception if the evaluation fails
     */
    public static void main(String[] strArr) throws Exception {
        AbstractMultiLabelClassifier.evaluation(new PLST(), strArr);
    }
}
