package moa.classifiers.deeplearning;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import com.yahoo.labs.samoa.instances.Instance;
import weka.core.json.JSONInstances;

/* compiled from: MLP.java */
/* loaded from: input_file:lib/moa.jar:moa/classifiers/deeplearning/MiniBatch.class */
/**
 * Accumulates training samples into a single mini-batch of DJL {@link NDArray}s.
 *
 * <p>Every array is allocated from a private {@link NDManager} created in the constructor. When
 * {@link #miniBatchSize} samples have been added, {@link #d} and {@link #l} wrap the batched
 * features and labels. {@link #discardMiniBatch()} releases the arrays and closes the manager, so
 * an instance is single-use: create a new {@code MiniBatch} for the next batch.
 *
 * <p>Shape note: with {@code miniBatchSize == 1} the batch arrays are the 1-D sample arrays
 * themselves (no leading batch axis); for larger sizes samples are stacked along axis 0 into 2-D
 * arrays.
 */
class MiniBatch {
    /** Owns every array created by this batch; {@code null} once the batch has been discarded. */
    private transient NDManager trainingNDManager;

    /** Number of samples after which the batch is considered full (always at least 1). */
    public int miniBatchSize;

    /** Features accumulated so far (FLOAT32); {@code null} while empty or after discard. */
    public transient NDArray trainMiniBatchData = null;

    /**
     * Labels accumulated so far; {@code null} while empty or after discard. The dtype depends on
     * which {@code addToMiniBatch} overload supplied the samples.
     */
    public transient NDArray trainMiniBatchLabels = null;

    /** Samples added since construction; reset to 0 by {@link #discardMiniBatch()}. */
    public int itemsInMiniBatch = 0;

    /** {@link #trainMiniBatchData} wrapped in an {@link NDList}; set only once the batch is full. */
    public NDList d = null;

    /** {@link #trainMiniBatchLabels} wrapped in an {@link NDList}; set only once the batch is full. */
    public NDList l = null;

    /**
     * Creates an empty mini-batch whose arrays are allocated on {@code device}.
     *
     * @param device device used by the backing {@link NDManager}
     * @param miniBatchSize number of samples that make the batch full; must be at least 1
     * @throws IllegalArgumentException if {@code miniBatchSize < 1}
     */
    public MiniBatch(Device device, int miniBatchSize) {
        // Validate before allocating the manager so a bad argument cannot leak native resources.
        if (miniBatchSize < 1) {
            throw new IllegalArgumentException("miniBatchSize must be >= 1, got " + miniBatchSize);
        }
        this.miniBatchSize = miniBatchSize;
        this.trainingNDManager = NDManager.newBaseManager(device);
    }

    /**
     * Appends one sample given as raw feature and label vectors.
     *
     * <p>Features are converted to FLOAT32. Labels are created from the {@code double[]} as-is,
     * without that conversion. NOTE(review): this differs from {@link #addToMiniBatch(Instance)},
     * which stores labels as FLOAT32; confirm downstream code tolerates both.
     *
     * @param features feature values of the sample
     * @param labels label values of the sample
     * @throws IllegalStateException if the batch is already full or has been discarded
     */
    public void addToMiniBatch(double[] features, double[] labels) {
        ensureAcceptingSamples();
        NDArray row = trainingNDManager.create(toFloat32(features, features.length));
        NDArray label = trainingNDManager.create(labels);
        append(row, label);
    }

    /**
     * Appends one sample taken from a MOA instance.
     *
     * <p>All entries of {@link Instance#toDoubleArray()} except the last are used as FLOAT32
     * features, and the last entry becomes a one-element FLOAT32 label. NOTE(review): this assumes
     * the class value is the last attribute; confirm against the stream's class index.
     *
     * @param instance sample to add
     * @throws IllegalArgumentException if the instance has no values at all
     * @throws IllegalStateException if the batch is already full or has been discarded
     */
    public void addToMiniBatch(Instance instance) {
        ensureAcceptingSamples();
        double[] values = instance.toDoubleArray();
        if (values.length == 0) {
            throw new IllegalArgumentException("instance has no values; expected features followed by the class value");
        }
        int numFeatures = values.length - 1;
        NDArray row = trainingNDManager.create(toFloat32(values, numFeatures));
        NDArray label = trainingNDManager.create(new float[] {(float) values[numFeatures]});
        append(row, label);
    }

    /**
     * Reports whether the batch holds {@link #miniBatchSize} samples.
     *
     * @return {@code true} once {@link #d} and {@link #l} have been populated
     */
    public boolean miniBatchFull() {
        return this.itemsInMiniBatch >= this.miniBatchSize;
    }

    /**
     * Releases all arrays held by this batch and closes its {@link NDManager}.
     *
     * <p>Safe to call more than once; after the first call the batch cannot accept new samples.
     */
    public void discardMiniBatch() {
        if (this.d != null) {
            this.d.close();
            this.d = null;
        }
        if (this.trainMiniBatchData != null) {
            this.trainMiniBatchData.close();
            this.trainMiniBatchData = null;
        }
        if (this.l != null) {
            this.l.close();
            this.l = null;
        }
        if (this.trainMiniBatchLabels != null) {
            this.trainMiniBatchLabels.close();
            this.trainMiniBatchLabels = null;
        }
        // Guarded so a repeated discard does not dereference the already-cleared manager.
        if (this.trainingNDManager != null) {
            this.trainingNDManager.close();
            this.trainingNDManager = null;
        }
        this.itemsInMiniBatch = 0;
    }

    /** Fails fast when the batch was discarded or already holds {@link #miniBatchSize} samples. */
    private void ensureAcceptingSamples() {
        if (trainingNDManager == null) {
            throw new IllegalStateException("MiniBatch has already been discarded");
        }
        if (itemsInMiniBatch >= miniBatchSize) {
            throw new IllegalStateException("MiniBatch is already full (" + miniBatchSize + " samples)");
        }
    }

    /**
     * Adds one 1-D feature row and its 1-D label to the batch, then publishes {@link #d} / {@link #l}
     * when the batch becomes full.
     */
    private void append(NDArray row, NDArray label) {
        if (itemsInMiniBatch == 0) {
            // First sample: the batch is just this sample (still 1-D).
            trainMiniBatchData = row;
            trainMiniBatchLabels = label;
        } else if (itemsInMiniBatch == 1) {
            // Second sample: stacking the two 1-D arrays introduces the batch axis 0.
            replaceBatch(trainMiniBatchData.stack(row, 0), trainMiniBatchLabels.stack(label, 0));
            row.close();
            label.close();
        } else {
            // Later samples: reshape to a single-row matrix and concatenate along the batch axis.
            NDArray rowMatrix = row.reshape(new Shape(1, row.size()));
            NDArray labelMatrix = label.reshape(new Shape(1, label.size()));
            replaceBatch(trainMiniBatchData.concat(rowMatrix, 0), trainMiniBatchLabels.concat(labelMatrix, 0));
            rowMatrix.close();
            labelMatrix.close();
            row.close();
            label.close();
        }
        itemsInMiniBatch++;
        if (itemsInMiniBatch == miniBatchSize) {
            d = new NDList(trainMiniBatchData);
            l = new NDList(trainMiniBatchLabels);
        }
    }

    /**
     * Swaps in a newly grown batch and closes the superseded arrays. stack/concat always allocate
     * fresh arrays; closing the old ones keeps memory proportional to the current batch instead of
     * accumulating every intermediate until {@link #discardMiniBatch()}.
     */
    private void replaceBatch(NDArray data, NDArray labels) {
        NDArray oldData = trainMiniBatchData;
        NDArray oldLabels = trainMiniBatchLabels;
        trainMiniBatchData = data;
        trainMiniBatchLabels = labels;
        oldData.close();
        oldLabels.close();
    }

    /** Copies the first {@code length} entries of {@code values} into a new float array. */
    private static float[] toFloat32(double[] values, int length) {
        float[] out = new float[length];
        for (int i = 0; i < length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }
}
