package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.pooling.Pool;

/* loaded from: input_file:lib/model-zoo-0.9.0.jar:ai/djl/basicmodelzoo/cv/classification/LeNet.class */
/**
 * Factory for a LeNet-style convolutional network assembled as a {@link SequentialBlock}.
 *
 * <p>The network consists of two convolution / sigmoid / average-pooling stages followed by a
 * flatten step and three fully connected layers. The sizes of the two convolutions and the first
 * two fully connected layers are configurable through {@link Builder#setNumChannels(int[])}; the
 * final layer always has 10 units.
 *
 * <p>This class is not instantiable; use {@link #builder()} to obtain a {@link Builder}.
 */
public final class LeNet {

    /**
     * Builder that collects the layer sizes used by {@link LeNet#leNet(Builder)}.
     *
     * <p>Defaults to {@code {6, 16, 120, 84}}.
     */
    public static final class Builder {
        /** Number of configurable layers; {@link #numChannels} must have exactly this length. */
        int numLayers = 4;

        /**
         * Sizes of the configurable layers, in order: filters of the first convolution, filters
         * of the second convolution, units of the first linear layer, units of the second linear
         * layer.
         */
        int[] numChannels = {6, 16, 120, 84};

        Builder() {
        }

        /**
         * Sets the sizes of the configurable layers.
         *
         * <p>The array is copied, so later modifications of {@code iArr} by the caller do not
         * affect this builder.
         *
         * @param iArr the layer sizes; must contain exactly {@code numLayers} entries
         * @return this builder
         * @throws NullPointerException if {@code iArr} is {@code null}
         * @throws IllegalArgumentException if {@code iArr} does not have {@code numLayers} entries
         */
        public Builder setNumChannels(int[] iArr) {
            if (iArr == null) {
                throw new NullPointerException("numChannels must not be null");
            }
            if (iArr.length != this.numLayers) {
                throw new IllegalArgumentException(
                        "number of channels must be equal to "
                                + this.numLayers
                                + ", but was "
                                + iArr.length);
            }
            // Defensive copy: keep the builder's configuration independent of the caller's array.
            this.numChannels = iArr.clone();
            return this;
        }

        /**
         * Builds the network described by this builder.
         *
         * @return the assembled {@link Block}
         */
        public Block build() {
            return LeNet.leNet(this);
        }
    }

    private LeNet() {
    }

    /**
     * Assembles the LeNet block from the layer sizes held by the given builder.
     *
     * @param builder the builder supplying {@code numChannels}
     * @return a {@link SequentialBlock} containing the layers in order
     */
    public static Block leNet(Builder builder) {
        int[] channels = builder.numChannels;

        // The three Shape arguments passed to avgPool2dBlock are presumably kernel, stride and
        // padding, in that order -- confirm against the Pool API of the DJL version in use.
        return new SequentialBlock()
                // Stage 1: 5x5 convolution (padding 2x2, no bias) -> sigmoid -> average pooling.
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(5, 5))
                                .optPadding(new Shape(2, 2))
                                .optBias(false)
                                .setFilters(channels[0])
                                .build())
                .add(Activation::sigmoid)
                .add(Pool.avgPool2dBlock(new Shape(5, 5), new Shape(2, 2), new Shape(2, 2)))
                // Stage 2: 5x5 convolution -> sigmoid -> average pooling.
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(5, 5))
                                .setFilters(channels[1])
                                .build())
                .add(Activation::sigmoid)
                .add(Pool.avgPool2dBlock(new Shape(5, 5), new Shape(2, 2), new Shape(2, 2)))
                // Classifier head: flatten, two configurable linear layers, 10-unit output layer.
                .add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setUnits(channels[2]).build())
                .add(Activation::sigmoid)
                .add(Linear.builder().setUnits(channels[3]).build())
                .add(Activation::sigmoid)
                .add(Linear.builder().setUnits(10L).build());
    }

    /**
     * Creates a new {@link Builder} initialised with the default layer sizes.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
