package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import ai.djl.util.RandomUtils;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:lib/api-0.9.0.jar:ai/djl/training/dataset/RandomAccessDataset.class */
/**
 * A {@link Dataset} whose records can be fetched individually by index.
 *
 * <p>Subclasses implement {@link #get(NDManager, long)} and {@link #availableSize()}. This class
 * supplies batched iteration (via a {@link Sampler} and data/label {@link Batchifier}s), optional
 * data and target transform pipelines, optional prefetching on an {@link ExecutorService}, an
 * upper size limit, and helpers that split the dataset into index-based subsets.
 */
public abstract class RandomAccessDataset implements Dataset {
    protected Sampler sampler;
    protected Batchifier dataBatchifier;
    protected Batchifier labelBatchifier;
    protected Pipeline pipeline;
    protected Pipeline targetPipeline;
    protected ExecutorService executor;
    protected int prefetchNumber;
    protected long limit;
    protected Device device;

    /**
     * Base builder shared by concrete {@link RandomAccessDataset} builders.
     *
     * <p>Defaults: both batchifiers are {@link Batchifier#STACK}, the limit is {@link
     * Long#MAX_VALUE}, and no sampler, pipeline, executor or device is set.
     *
     * @param <T> the concrete builder type, returned by every setter to allow chaining
     */
    public abstract static class BaseBuilder<T extends BaseBuilder<T>> {
        protected Sampler sampler;
        protected Pipeline pipeline;
        protected Pipeline targetPipeline;
        protected ExecutorService executor;
        protected int prefetchNumber;
        protected Device device;
        protected Batchifier dataBatchifier = Batchifier.STACK;
        protected Batchifier labelBatchifier = Batchifier.STACK;
        protected long limit = Long.MAX_VALUE;

        /**
         * Returns the configured sampler.
         *
         * @return the sampler, never {@code null}
         * @throws NullPointerException if no sampler has been set
         */
        public Sampler getSampler() {
            Objects.requireNonNull(this.sampler, "The sampler must be set");
            return this.sampler;
        }

        /**
         * Uses a {@link BatchSampler} that keeps the last, possibly incomplete, batch.
         *
         * @param batchSize the number of records per batch
         * @param random {@code true} to sample in random order, {@code false} for sequential order
         * @return this builder
         */
        public T setSampling(int batchSize, boolean random) {
            return setSampling(batchSize, random, false);
        }

        /**
         * Uses a {@link BatchSampler} over a random or sequential sampler.
         *
         * @param batchSize the number of records per batch
         * @param random {@code true} to sample in random order, {@code false} for sequential order
         * @param dropLast passed through to the {@link BatchSampler} as its third argument
         *     (presumably whether an incomplete final batch is dropped — see BatchSampler)
         * @return this builder
         */
        public T setSampling(int batchSize, boolean random, boolean dropLast) {
            Sampler.SubSampler subSampler =
                    random ? new RandomSampler() : new SequenceSampler();
            this.sampler = new BatchSampler(subSampler, batchSize, dropLast);
            return self();
        }

        /**
         * Uses the given sampler as-is.
         *
         * @param sampler the sampler to use
         * @return this builder
         */
        public T setSampling(Sampler sampler) {
            this.sampler = sampler;
            return self();
        }

        /**
         * Sets the batchifier used to combine data arrays into a batch.
         *
         * @param batchifier the data batchifier
         * @return this builder
         */
        public T optDataBatchifier(Batchifier batchifier) {
            this.dataBatchifier = batchifier;
            return self();
        }

        /**
         * Sets the batchifier used to combine label arrays into a batch.
         *
         * @param batchifier the label batchifier
         * @return this builder
         */
        public T optLabelBatchifier(Batchifier batchifier) {
            this.labelBatchifier = batchifier;
            return self();
        }

        /**
         * Replaces the data pipeline.
         *
         * @param pipeline the data pipeline
         * @return this builder
         */
        public T optPipeline(Pipeline pipeline) {
            this.pipeline = pipeline;
            return self();
        }

        /**
         * Appends a transform to the data pipeline, creating the pipeline if needed.
         *
         * @param transform the transform to append
         * @return this builder
         */
        public T addTransform(Transform transform) {
            if (this.pipeline == null) {
                this.pipeline = new Pipeline();
            }
            this.pipeline.add(transform);
            return self();
        }

        /**
         * Replaces the target (label) pipeline.
         *
         * @param pipeline the target pipeline
         * @return this builder
         */
        public T optTargetPipeline(Pipeline pipeline) {
            this.targetPipeline = pipeline;
            return self();
        }

        /**
         * Appends a transform to the target pipeline, creating the pipeline if needed.
         *
         * @param transform the transform to append
         * @return this builder
         */
        public T addTargetTransform(Transform transform) {
            if (this.targetPipeline == null) {
                this.targetPipeline = new Pipeline();
            }
            this.targetPipeline.add(transform);
            return self();
        }

        /**
         * Sets the executor and prefetch count handed to the data iterator.
         *
         * @param executor the executor service
         * @param prefetchNumber the prefetch count passed to the data iterator
         * @return this builder
         */
        public T optExecutor(ExecutorService executor, int prefetchNumber) {
            this.executor = executor;
            this.prefetchNumber = prefetchNumber;
            return self();
        }

        /**
         * Sets the device handed to the data iterator.
         *
         * @param device the device
         * @return this builder
         */
        public T optDevice(Device device) {
            this.device = device;
            return self();
        }

        /**
         * Caps the dataset size; {@link #size()} returns the minimum of this and the available size.
         *
         * @param limit the maximum number of records
         * @return this builder
         */
        public T optLimit(long limit) {
            this.limit = limit;
            return self();
        }

        /**
         * Returns this builder typed as the concrete subclass.
         *
         * @return this builder
         */
        protected abstract T self();
    }

    /**
     * A view over a contiguous range {@code [from, to)} of an index permutation of a parent
     * dataset. It copies the parent's sampling/batching configuration but not its limit.
     */
    private static final class SubDataset extends RandomAccessDataset {
        private final RandomAccessDataset dataset;
        private final int[] indices;
        private final int from;
        private final int to;

        SubDataset(RandomAccessDataset dataset, int[] indices, int from, int to) {
            this.dataset = dataset;
            this.indices = indices;
            this.from = from;
            this.to = to;
            this.sampler = dataset.sampler;
            this.dataBatchifier = dataset.dataBatchifier;
            this.labelBatchifier = dataset.labelBatchifier;
            this.pipeline = dataset.pipeline;
            this.targetPipeline = dataset.targetPipeline;
            this.executor = dataset.executor;
            this.prefetchNumber = dataset.prefetchNumber;
            this.device = dataset.device;
            this.limit = Long.MAX_VALUE;
        }

        /**
         * Fetches the {@code index}-th record of this subset from the parent dataset.
         *
         * @throws IndexOutOfBoundsException if {@code index} is negative or not below {@link
         *     #size()}
         */
        @Override
        public Record get(NDManager manager, long index) throws IOException {
            long size = size();
            // A negative index must be rejected too: with from > 0 it would otherwise silently
            // read an element that belongs to a neighbouring split.
            if (index < 0 || index >= size) {
                throw new IndexOutOfBoundsException(
                        "index(" + index + ") out of range [0, " + size + ").");
            }
            return dataset.get(manager, indices[Math.toIntExact(index) + from]);
        }

        @Override
        protected long availableSize() {
            return to - from;
        }

        /** No-op: the subset does not prepare its parent dataset. */
        @Override
        public void prepare(Progress progress) {}
    }

    RandomAccessDataset() {}

    /**
     * Initializes the dataset from a builder.
     *
     * @param builder the builder holding the configuration
     * @throws NullPointerException if the builder has no sampler
     */
    public RandomAccessDataset(BaseBuilder<?> builder) {
        this.sampler = builder.getSampler();
        this.dataBatchifier = builder.dataBatchifier;
        this.labelBatchifier = builder.labelBatchifier;
        this.pipeline = builder.pipeline;
        this.targetPipeline = builder.targetPipeline;
        this.executor = builder.executor;
        this.prefetchNumber = builder.prefetchNumber;
        this.limit = builder.limit;
        this.device = builder.device;
    }

    /**
     * Returns the record at the given index.
     *
     * @param manager the manager used to create the record's arrays
     * @param index the record index
     * @return the record
     * @throws IOException if the record cannot be read
     */
    public abstract Record get(NDManager manager, long index) throws IOException;

    /**
     * Prepares the dataset and returns batches drawn with the configured sampler.
     *
     * @param manager the manager used to create batch arrays
     * @return an iterable over batches
     * @throws IOException if preparation fails
     * @throws TranslateException if preparation fails
     */
    @Override
    public Iterable<Batch> getData(NDManager manager) throws IOException, TranslateException {
        return getData(manager, this.sampler);
    }

    /**
     * Prepares the dataset and returns batches drawn with the given sampler.
     *
     * @param manager the manager used to create batch arrays
     * @param sampler the sampler to use instead of the configured one
     * @return an iterable over batches
     * @throws IOException if preparation fails
     * @throws TranslateException if preparation fails
     */
    public Iterable<Batch> getData(NDManager manager, Sampler sampler)
            throws IOException, TranslateException {
        prepare();
        return new DataIterable(
                this,
                manager,
                sampler,
                this.dataBatchifier,
                this.labelBatchifier,
                this.pipeline,
                this.targetPipeline,
                this.executor,
                this.prefetchNumber,
                this.device);
    }

    /**
     * Returns the number of usable records: the smaller of the limit and {@link #availableSize()}.
     *
     * @return the dataset size
     */
    public long size() {
        return Math.min(this.limit, availableSize());
    }

    /**
     * Returns the number of records the underlying source provides, ignoring the limit.
     *
     * @return the available size
     */
    protected abstract long availableSize();

    /**
     * Randomly partitions the dataset into subsets sized proportionally to {@code ratio}.
     *
     * <p>Sizes are truncated, and the last subset takes whatever remains.
     *
     * @param ratio relative weights, at least two, each non-negative, with a positive sum
     * @return one subset per ratio entry
     * @throws IllegalArgumentException if the ratios are invalid
     * @throws IOException if preparation fails
     * @throws TranslateException if preparation fails
     */
    public RandomAccessDataset[] randomSplit(int... ratio) throws IOException, TranslateException {
        if (ratio.length < 2) {
            throw new IllegalArgumentException("Requires at least two split portion.");
        }
        // Sum in long: an int sum of large ratios could overflow and corrupt every split size.
        long ratioSum = 0;
        for (int weight : ratio) {
            if (weight < 0) {
                throw new IllegalArgumentException(
                        "Split ratios must be non-negative: " + Arrays.toString(ratio));
            }
            ratioSum += weight;
        }
        if (ratioSum == 0) {
            throw new IllegalArgumentException("Split ratios must not all be zero.");
        }
        prepare();

        int size = Math.toIntExact(size());
        int[] indices = IntStream.range(0, size).toArray();
        // Fisher-Yates: swapping position i with a position in [0, i] yields a uniform
        // permutation; swapping with any position in [0, size) is biased.
        for (int i = size - 1; i > 0; i--) {
            swap(indices, i, RandomUtils.nextInt(i + 1));
        }

        RandomAccessDataset[] splits = new RandomAccessDataset[ratio.length];
        double sum = ratioSum;
        int start = 0;
        for (int i = 0; i < ratio.length - 1; i++) {
            int end = start + (int) ((ratio[i] / sum) * size);
            splits[i] = new SubDataset(this, indices, start, end);
            start = end;
        }
        splits[ratio.length - 1] = new SubDataset(this, indices, start, size);
        return splits;
    }

    /**
     * Returns a view of the records in {@code [fromIndex, toIndex)} in their original order.
     *
     * <p>The index range is validated against {@link #size()} at call time, so the dataset should
     * already be prepared.
     *
     * @param fromIndex the first index, inclusive
     * @param toIndex the last index, exclusive
     * @return the subset view
     * @throws IndexOutOfBoundsException if the range is not within {@code [0, size()]} or
     *     {@code fromIndex > toIndex}
     */
    public RandomAccessDataset subDataset(int fromIndex, int toIndex) {
        int size = Math.toIntExact(size());
        if (fromIndex < 0 || toIndex > size || fromIndex > toIndex) {
            throw new IndexOutOfBoundsException(
                    "Range ["
                            + fromIndex
                            + ", "
                            + toIndex
                            + ") is out of bounds for size "
                            + size
                            + ".");
        }
        return new SubDataset(this, IntStream.range(0, size).toArray(), fromIndex, toIndex);
    }

    /**
     * Reads every record sequentially and returns its flattened data and labels.
     *
     * <p>Entry {@code i} of each array holds record {@code i}'s arrays flattened and concatenated,
     * or {@code null} if that record's list is empty.
     *
     * @return a pair of (data, labels)
     * @throws IOException if reading fails
     * @throws TranslateException if reading fails
     */
    public Pair<Number[][], Number[][]> toArray() throws IOException, TranslateException {
        try (NDManager manager = NDManager.newBaseManager()) {
            BatchSampler sampler = new BatchSampler(new SequenceSampler(), 1, false);
            int size = Math.toIntExact(size());
            Number[][] data = new Number[size][];
            Number[][] labels = new Number[size][];
            int index = 0;
            for (Batch batch : getData(manager, sampler)) {
                try {
                    data[index] = flattenRecord(batch.getData());
                    labels[index] = flattenRecord(batch.getLabels());
                } finally {
                    batch.close();
                }
                index++;
            }
            return new Pair<>(data, labels);
        }
    }

    /**
     * Flattens every array in {@code list} and concatenates them into one number array.
     *
     * @param list the arrays to flatten
     * @return the flattened values, or {@code null} if {@code list} is empty
     */
    private Number[] flattenRecord(NDList list) {
        NDList flattened =
                new NDList(list.stream().map(NDArray::flatten).collect(Collectors.toList()));
        if (flattened.isEmpty()) {
            return null;
        }
        return flattened.size() == 1
                ? flattened.get(0).toArray()
                : NDArrays.concat(flattened).toArray();
    }

    /** Swaps {@code array[i]} and {@code array[j]}. */
    private static void swap(int[] array, int i, int j) {
        int tmp = array[i];
        array[i] = array[j];
        array[j] = tmp;
    }
}
