package moa.streams.generators.multilabel;

import cern.colt.matrix.impl.AbstractFormatter;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.instances.Range;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
import moa.core.Example;
import moa.core.FastVector;
import moa.core.InstanceExample;
import moa.core.MultilabelInstancesHeader;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.AbstractOptionHandler;
import moa.options.ClassOption;
import moa.streams.InstanceStream;
import moa.streams.MultiTargetInstanceStream;
import moa.tasks.TaskMonitor;
import org.apache.xerces.impl.xs.SchemaSymbols;

/* loaded from: input_file:lib/moa.jar:moa/streams/generators/multilabel/MetaMultilabelGenerator.class */
/**
 * Multi-label stream generator built on top of a binary-class {@link InstanceStream}.
 *
 * <p>On {@link #restart()} it draws, from a seeded meta random generator, a prior probability per
 * label, a random pairwise label-dependency structure and a table of pairwise conditional
 * probabilities. Each generated example first samples a labelset from that model
 * ({@link #generateSet()}), then fills feature attribute {@code a} from a binary-stream instance of
 * class 1 if the labelset contains the labelset assigned to {@code a} (see
 * {@link #getTopCombinations(int)}), and from an instance of class 0 otherwise.
 *
 * <p>Generated instances have the {@code L} label attributes (nominal "0"/"1") at indices
 * {@code 0..L-1}, followed by the first {@code A} attributes of the binary generator's header (its
 * last attribute — presumably the class — is dropped).
 */
public class MetaMultilabelGenerator extends AbstractOptionHandler implements MultiTargetInstanceStream {
    private static final long serialVersionUID = 1;

    /** Number of labelsets sampled to estimate the most common label combinations. */
    private static final int TOP_COMBINATION_SAMPLES = 100000;

    /** Maximum binary-stream draws while looking for an instance of a requested class. */
    private static final int MAX_BINARY_DRAWS = 1000;

    /** Maximum number of buffered binary instances per class. */
    private static final int MAX_QUEUED_PER_CLASS = 100;

    /** Absolute tolerance (scaled by the target) when rescaling priors towards the target cardinality. */
    private static final double PRIOR_SUM_TOLERANCE = 1e-9d;

    public ClassOption binaryGeneratorOption = new ClassOption("binaryGenerator", 's', "Binary Generator (specify the number of attributes here, but only two classes!).", InstanceStream.class, "generators.RandomTreeGenerator");
    public IntOption metaRandomSeedOption = new IntOption("metaRandomSeed", 'm', "Random seed (for the meta process). Use two streams with the same seed and r > 0.0 in the second stream if you wish to introduce drift to the label dependencies without changing the underlying concept.", 1);
    public IntOption numLabelsOption = new IntOption("numLabels", 'c', "Number of labels.", 10, 2, Integer.MAX_VALUE);
    public IntOption skewOption = new IntOption("skew", 'k', "Skewed label distribution: 1 (default) = yes; 0 = no (relatively uniform) @NOTE: not currently implemented.", 1, 0, 1);
    public FloatOption labelCardinalityOption = new FloatOption("labelCardinality", 'z', "Desired label cardinality (average number of labels per example).", 1.5d, 0.0d, 2.147483647E9d);
    public FloatOption labelCardinalityVarOption = new FloatOption("labelCardinalityVar", 'v', "Desired label cardinality variance (variance of z) @NOTE: not currently implemented.", 1.0d, 0.0d, 2.147483647E9d);
    public FloatOption labelCardinalityRatioOption = new FloatOption("labelDependency", 'u', "Specifies how much label dependency from 0 (total independence) to 1 (full dependence).", 0.25d, 0.0d, 1.0d);
    public FloatOption labelDependencyChangeRatioOption = new FloatOption("labelDependencyRatioChange", 'r', "Each label-pair dependency has a 'r' chance of being modified. Use this option on the second of two streams with the same random seed (-m) to introduce label-dependence drift.", 0.0d, 0.0d, 1.0d);

    /** Header of the generated multi-label stream (built in {@link #restart()}). */
    protected MultilabelInstancesHeader m_MultilabelInstancesHeader = null;
    /** Underlying binary-class stream supplying feature values. */
    protected InstanceStream m_BinaryGenerator = null;
    /** Dataset template attached to every generated instance. */
    protected Instances multilabelStreamTemplate = null;
    /** Random source for the whole meta process; reseeded in {@link #restart()}. */
    protected Random m_MetaRandom = new Random();
    /** Number of labels. */
    protected int m_L = 0;
    /** Number of feature attributes (binary header attributes minus one). */
    protected int m_A = 0;
    /** Per-label prior probabilities (not normalized; they sum to roughly the target cardinality). */
    protected double[] priors = null;
    /** Normalized copy of {@link #priors}, used to pick the first label of each labelset. */
    protected double[] priors_norm = null;
    /** Pairwise table read by {@link #joint(int, int[])}: entry [i][j] is used when label j is already on. */
    protected double[][] Conditional = null;
    /** One labelset per feature attribute; decides which binary class feeds that attribute. */
    protected HashSet[] m_TopCombinations = null;
    /** Per-class buffers (index = binary class value) of drawn but not yet used binary instances. */
    LinkedList<Instance>[] queue = null;

    @Override
    public void prepareForUseImpl(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
        restart();
    }

    /**
     * (Re)initializes the binary generator and redraws the whole label model from the meta seed.
     *
     * @throws IllegalArgumentException if the requested label cardinality exceeds the number of labels
     */
    @Override
    @SuppressWarnings("unchecked")
    public void restart() {
        this.m_L = this.numLabelsOption.getValue();
        if (this.labelCardinalityOption.getValue() > this.m_L) {
            // Throw rather than System.exit(1): exiting from library code terminates the host JVM,
            // whereas an exception lets the caller report the bad configuration.
            throw new IllegalArgumentException("Error: Label cardinality (z) cannot be greater than the number of labels (c)!");
        }
        this.m_BinaryGenerator = (InstanceStream) getPreparedClassOption(this.binaryGeneratorOption);
        this.m_BinaryGenerator.restart();
        this.m_A = this.m_BinaryGenerator.getHeader().numAttributes() - 1;
        this.m_MetaRandom = new Random(this.metaRandomSeedOption.getValue());
        this.queue = new LinkedList[2];
        for (int c = 0; c < this.queue.length; c++) {
            this.queue[c] = new LinkedList<>();
        }
        this.m_MultilabelInstancesHeader = generateMultilabelHeader(this.m_BinaryGenerator.getHeader());

        // Every step below draws from m_MetaRandom; their order defines the stream produced for a
        // given seed, so it must not be rearranged.
        boolean skew = this.skewOption.getValue() >= 1;
        this.priors = generatePriors(this.m_MetaRandom, this.m_L, this.labelCardinalityOption.getValue(), skew);
        boolean[][] dependencies = modifyDependencyMatrix(new boolean[this.m_L][this.m_L], this.labelCardinalityRatioOption.getValue(), this.m_MetaRandom);
        double changeRatio = this.labelDependencyChangeRatioOption.getValue();
        if (changeRatio > 0.0d) {
            // Drift: perturb priors and flip some dependencies on top of the same seeded model.
            this.priors = modifyPriorVector(this.priors, changeRatio, this.m_MetaRandom, skew);
            modifyDependencyMatrix(dependencies, changeRatio, this.m_MetaRandom);
        }
        this.Conditional = generateConditional(this.priors, dependencies);
        this.priors_norm = Arrays.copyOf(this.priors, this.priors.length);
        Utils.normalize(this.priors_norm);
        this.m_TopCombinations = getTopCombinations(this.m_A);
    }

    /**
     * Builds the multi-label header: drops the last attribute of the binary header and prepends
     * {@code m_L} nominal label attributes {@code class0..class(L-1)} with values "0" and "1".
     *
     * @param instances header of the binary generator
     * @return the multi-label header (also stores the template in {@link #multilabelStreamTemplate})
     */
    protected MultilabelInstancesHeader generateMultilabelHeader(Instances instances) {
        Instances template = new Instances(instances, 0, 0);
        template.deleteAttributeAt(Integer.valueOf(template.numAttributes() - 1));
        FastVector binaryValues = new FastVector();
        binaryValues.addElement("0");
        binaryValues.addElement("1");
        for (int l = 0; l < this.m_L; l++) {
            template.insertAttributeAt(new Attribute("class" + l, binaryValues), l);
        }
        Range range = new Range(Integer.toString(this.numLabelsOption.getValue()));
        this.multilabelStreamTemplate = template;
        this.multilabelStreamTemplate.setRelationName("SYN_Z" + this.labelCardinalityOption.getValue() + "L" + this.m_L + "X" + this.m_A + "S" + this.metaRandomSeedOption.getValue() + ": -C " + this.m_L);
        this.multilabelStreamTemplate.setClassIndex(Integer.MAX_VALUE);
        this.multilabelStreamTemplate.setRangeOutputIndices(range);
        MultilabelInstancesHeader header = new MultilabelInstancesHeader(this.multilabelStreamTemplate, this.m_L);
        header.setRangeOutputIndices(range);
        return header;
    }

    /**
     * Draws random per-label priors and rescales them (capping each at 1) until they sum to the
     * target cardinality.
     *
     * @param random      source of randomness
     * @param numLabels   number of labels
     * @param cardinality target sum of the priors (must not exceed {@code numLabels})
     * @param skew        currently unused (the skew option is documented as not implemented)
     * @return the prior vector
     */
    private double[] generatePriors(Random random, int numLabels, double cardinality, boolean skew) {
        double[] p = new double[numLabels];
        for (int l = 0; l < numLabels; l++) {
            p[l] = random.nextDouble();
        }
        // The tolerance prevents an endless loop: once the sum is within a few ulps of the target,
        // rescaling by ~1/(1-2^-53) can round every value back to itself.
        double threshold = cardinality - PRIOR_SUM_TOLERANCE * Math.max(1.0d, cardinality);
        do {
            double scale = Utils.sum(p) / cardinality;
            for (int l = 0; l < numLabels; l++) {
                p[l] = Math.min(1.0d, p[l] / scale);
            }
        } while (Utils.sum(p) < threshold);
        return p;
    }

    /**
     * Returns a binary-stream instance whose rounded class value equals {@code i}. A buffered one is
     * used first; otherwise the stream is drawn from, buffering instances of the other class.
     *
     * @throws IllegalStateException if no such instance appears within {@link #MAX_BINARY_DRAWS} draws
     */
    private Instance getNextWithBinary(int i) {
        if (!this.queue[i].isEmpty()) {
            return this.queue[i].remove();
        }
        for (int draw = 0; draw < MAX_BINARY_DRAWS; draw++) {
            Instance data = this.m_BinaryGenerator.nextInstance2().getData();
            int cls = (int) Math.round(data.classValue());
            if (cls == i) {
                return data;
            }
            // Only classes 0/1 have buffers; anything else would index out of bounds.
            if (cls >= 0 && cls < this.queue.length && this.queue[cls].size() < MAX_QUEUED_PER_CLASS) {
                this.queue[cls].add(data);
            }
        }
        throw new IllegalStateException("[Overflow] The binary stream is too skewed, could not get an example of class " + i);
    }

    @Override
    /* renamed from: nextInstance */
    public Example<Instance> nextInstance2() {
        return new InstanceExample(generateMLInstance(generateSet()));
    }

    /**
     * Samples one labelset: the first label from {@link #priors_norm}, then every other label in a
     * shuffled order with probability {@link #joint(int, int[])} given the labels already on.
     */
    private HashSet<Integer> generateSet() {
        int[] labels = new int[this.m_L];
        int first = samplePMF(this.priors_norm);
        labels[first] = 1;
        for (int l : getShuffledListToLWithoutK(this.m_L, first)) {
            double p = joint(l, labels);
            labels[l] = p > this.m_MetaRandom.nextDouble() ? 1 : 0;
        }
        return vector2set(labels);
    }

    /**
     * Product of {@code Conditional[i][j]} over every other label {@code j} currently set to 1
     * ({@code 1.0} if there is none).
     */
    private double joint(int i, int[] iArr) {
        double p = 1.0d;
        for (int j = 0; j < iArr.length; j++) {
            if (j != i && iArr[j] == 1) {
                p *= this.Conditional[i][j];
            }
        }
        return p;
    }

    /**
     * Builds an instance for the given labelset: label attributes set to 0/1, and feature attribute
     * {@code a} copied from a class-1 binary instance if the labelset contains
     * {@code m_TopCombinations[a]}, otherwise from a class-0 binary instance.
     */
    private Instance generateMLInstance(HashSet<Integer> hashSet) {
        DenseInstance instance = new DenseInstance(this.multilabelStreamTemplate.numAttributes());
        instance.setDataset(this.multilabelStreamTemplate);
        for (int l = 0; l < this.m_L; l++) {
            instance.setValue(l, 0.0d);
        }
        for (int l : hashSet) {
            instance.setValue(l, 1.0d);
        }
        Instance negative = getNextWithBinary(0);
        Instance positive = getNextWithBinary(1);
        for (int a = 0; a < this.m_A; a++) {
            Instance source = hashSet.containsAll(this.m_TopCombinations[a]) ? positive : negative;
            instance.setValue(this.m_L + a, source.value(a));
        }
        return instance;
    }

    /**
     * Samples an index from a probability mass function.
     *
     * @param dArr probabilities, expected to sum to 1
     * @return the sampled index, or -1 only if {@code dArr} is empty
     */
    private int samplePMF(double[] dArr) {
        double r = this.m_MetaRandom.nextDouble();
        double cumulative = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            cumulative += dArr[k];
            if (r < cumulative) {
                return k;
            }
        }
        // Rounding can leave the cumulative sum slightly below 1.0; returning -1 would make the
        // caller index the label vector with -1. Fall back to the last entry with non-zero mass.
        for (int k = dArr.length - 1; k >= 0; k--) {
            if (dArr[k] > 0.0d) {
                return k;
            }
        }
        return dArr.length - 1;
    }

    /**
     * Replaces each prior, with probability {@code ratio}, by a fresh uniform draw (in place).
     *
     * @param priorVector priors to modify; the same array is returned
     * @param ratio       per-entry replacement probability
     * @param random      source of randomness
     * @param skew        currently unused
     */
    protected double[] modifyPriorVector(double[] priorVector, double ratio, Random random, boolean skew) {
        for (int l = 0; l < priorVector.length; l++) {
            if (random.nextDouble() < ratio) {
                priorVector[l] = random.nextDouble();
            }
        }
        return priorVector;
    }

    /**
     * Flips each upper-triangular entry ({@code j > i}) of the dependency matrix with probability
     * {@code ratio}, in place. Entries on or below the diagonal are never touched.
     *
     * @param matrix dependency flags; the same array is returned
     * @param ratio  per-pair flip probability
     * @param random source of randomness
     */
    protected boolean[][] modifyDependencyMatrix(boolean[][] matrix, double ratio, Random random) {
        for (int i = 0; i < matrix.length; i++) {
            for (int j = i + 1; j < matrix[i].length; j++) {
                if (random.nextDouble() <= ratio) {
                    matrix[i][j] = !matrix[i][j];
                }
            }
        }
        return matrix;
    }

    /**
     * Builds the pairwise table used by {@link #joint(int, int[])}. The diagonal holds the priors.
     * For an independent pair ({@code j<k}, flag false) the entries are {@code [j][k]=P[j]},
     * {@code [k][j]=P[k]}. For a dependent pair {@code [j][k]} is randomly either {@link #min} or
     * {@link #max} of the two priors, and {@code [k][j] = [j][k] * P[j] / P[k]}.
     *
     * <p>NOTE(review): the independent branch computes {@code [k][j]} as {@code [j][k]*P[k]/P[j]}
     * while the dependent branch uses {@code [j][k]*P[j]/P[k]} — Bayes' rule in opposite directions.
     * Kept unchanged to preserve the generated streams; confirm the intended semantics.
     */
    protected double[][] generateConditional(double[] dArr, boolean[][] zArr) {
        int n = dArr.length;
        double[][] table = new double[n][n];
        for (int i = 0; i < n; i++) {
            table[i][i] = dArr[i];
        }
        for (int j = 0; j < table.length; j++) {
            for (int k = j + 1; k < table[j].length; k++) {
                if (zArr[j][k]) {
                    table[j][k] = this.m_MetaRandom.nextBoolean() ? min(dArr[j], dArr[k]) : max(dArr[j], dArr[k]);
                    table[k][j] = (table[j][k] * table[j][j]) / table[k][k];
                } else {
                    table[j][k] = dArr[j];
                    table[k][j] = (table[j][k] * dArr[k]) / dArr[j];
                }
            }
        }
        return table;
    }

    /**
     * Chooses one labelset per feature attribute. It samples {@link #TOP_COMBINATION_SAMPLES}
     * labelsets, ranks them by frequency, and assigns the {@code i} most common ones proportionally to
     * their frequency (at least one slot each while slots remain). Leftover slots, which rounding can
     * leave, are padded with the ranked sets. The result is shuffled. Statistics go to stderr.
     *
     * @param i number of feature attributes
     * @return an array of {@code max(i, 0)} non-null labelsets
     */
    @SuppressWarnings("rawtypes")
    private HashSet[] getTopCombinations(int i) {
        if (i <= 0) {
            // No feature attributes: nothing to map.
            return new HashSet[0];
        }
        final HashMap<HashSet<Integer>, Integer> counts = new HashMap<>();
        double totalLabels = 0.0d;
        for (int s = 0; s < TOP_COMBINATION_SAMPLES; s++) {
            HashSet<Integer> labelset = generateSet();
            totalLabels += labelset.size();
            counts.merge(labelset, 1, Integer::sum);
        }
        double estimatedCardinality = totalLabels / TOP_COMBINATION_SAMPLES;

        // Most frequent first (Collections.sort is stable, so ties keep HashMap iteration order).
        ArrayList<HashSet<Integer>> ranked = new ArrayList<>(counts.keySet());
        Collections.sort(ranked, (a, b) -> counts.get(b).compareTo(counts.get(a)));

        System.err.println("The most common labelsets (from which we will build the map) will likely be: ");
        int numRanked = Math.min(i, ranked.size());
        double[] weights = new double[i];
        for (int r = 0; r < numRanked; r++) {
            HashSet<Integer> labelset = ranked.get(r);
            int count = counts.get(labelset);
            System.err.println(" " + labelset + " : " + ((count * 100.0d) / TOP_COMBINATION_SAMPLES) + "%");
            weights[r] = count;
        }
        System.err.println("Estimated Label Cardinality:  " + estimatedCardinality + "\n\n");
        System.err.println("Estimated % Unique Labelsets: " + ((counts.size() * 100.0d) / TOP_COMBINATION_SAMPLES) + "%\n\n");
        Utils.normalize(weights);

        HashSet[] map = new HashSet[i];
        int filled = 0;
        // Each iteration fills at least one slot, so r < i whenever the body runs.
        for (int r = 0; r < ranked.size() && filled < map.length; r++) {
            int copies = (int) Math.round(Math.max(weights[r] * map.length, 1.0d));
            for (int c = 0; c < copies && filled < map.length; c++) {
                map[filled++] = ranked.get(r);
            }
        }
        // Rounding can leave slots empty (e.g. three equal sets for ten slots gives 3+3+3); a null
        // entry would make containsAll() throw in generateMLInstance.
        for (int r = 0; filled < map.length; r++) {
            map[filled++] = ranked.get(r % numRanked);
        }
        Collections.shuffle(Arrays.asList(map), this.m_MetaRandom);
        return map;
    }

    @Override
    public InstancesHeader getHeader() {
        return this.m_MultilabelInstancesHeader;
    }

    @Override
    public String getPurposeString() {
        return "Generates a multi-label stream based on a binary random generator.";
    }

    /** The stream is unbounded. */
    @Override
    public long estimatedRemainingInstances() {
        return -1L;
    }

    @Override
    public boolean hasMoreInstances() {
        return true;
    }

    @Override
    public boolean isRestartable() {
        return true;
    }

    @Override
    public void getDescription(StringBuilder sb, int i) {
    }

    /** Converts a labelset to a 0/1 vector of length {@code i}. */
    private int[] set2vector(HashSet<Integer> hashSet, int i) {
        int[] vector = new int[i];
        for (int l : hashSet) {
            vector[l] = 1;
        }
        return vector;
    }

    /** Converts a label vector to the set of indices holding a positive value. */
    private HashSet<Integer> vector2set(int[] iArr) {
        HashSet<Integer> set = new HashSet<>();
        for (int l = 0; l < iArr.length; l++) {
            if (iArr[l] > 0) {
                set.add(l);
            }
        }
        return set;
    }

    /** Returns {@code min(1, d2 / d)}; used by {@link #generateConditional} for dependent pairs. */
    private double max(double d, double d2) {
        return Math.min(1.0d, d2 / d);
    }

    /** Returns {@code max(0, d + d2 - 1)}; used by {@link #generateConditional} for dependent pairs. */
    private double min(double d, double d2) {
        return Math.max(0.0d, (-1.0d) + d + d2);
    }

    /** Returns indices {@code 0..i-1} except {@code i2}, shuffled with the meta random source. */
    private ArrayList<Integer> getShuffledListToLWithoutK(int i, int i2) {
        ArrayList<Integer> indices = new ArrayList<>(i - 1);
        for (int l = 0; l < i; l++) {
            if (l != i2) {
                indices.add(l);
            }
        }
        Collections.shuffle(indices, this.m_MetaRandom);
        return indices;
    }

    public static void main(String[] strArr) {
    }

    /** Debug helper: prints a matrix to stdout. */
    private void printMatrix(double[][] dArr) {
        System.out.println("--- MATRIX ---");
        for (double[] row : dArr) {
            for (double v : row) {
                System.out.print(" " + Utils.doubleToString(v, 5, 3));
            }
            System.out.println("");
        }
    }

    /** Debug helper: prints a vector to stdout. */
    private void printVector(double[] dArr) {
        System.out.println("--- VECTOR ---");
        for (double d : dArr) {
            System.out.print(" " + Utils.doubleToString(d, 5, 3));
        }
        System.out.println("");
    }
}
