package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;

/* loaded from: input_file:lib/api-0.9.0.jar:ai/djl/nn/core/Linear.class */
/**
 * A fully connected block holding a {@code weight} parameter of shape
 * {@code (units, inputFeatures)} and an optional {@code bias} parameter of shape {@code (units)}.
 *
 * <p>The actual computation is delegated to the engine through
 * {@code NDArray.getNDArrayInternal().linear(...)}; see {@link #linear(NDArray, NDArray, NDArray)}.
 *
 * <p>{@code inputFeatures} and {@code inputShape} are only known after
 * {@link #beforeInitialize(Shape[])} or {@link #loadMetadata(byte, DataInputStream)} has run.
 */
public class Linear extends AbstractBlock {
    /** Current metadata encoding version written by this block. */
    private static final byte VERSION = 4;

    /** Number of output units (first dimension of the weight). */
    private long units;

    /** Size of the last input dimension (second dimension of the weight). */
    private long inputFeatures;

    /** Input shape with its last dimension removed. */
    private Shape inputShape;

    private Parameter weight;

    /** Bias parameter; stays {@code null} when the builder disabled it. */
    private Parameter bias;

    /** Builder for {@link Linear} blocks. */
    public static final class Builder {
        private long units;
        private boolean bias = true;

        Builder() {
        }

        /**
         * Sets the number of output units. Required; must be positive.
         *
         * @param units the number of output units
         * @return this builder
         */
        public Builder setUnits(long units) {
            this.units = units;
            return this;
        }

        /**
         * Enables or disables the bias parameter (enabled by default).
         *
         * @param bias {@code true} to create a bias parameter
         * @return this builder
         */
        public Builder optBias(boolean bias) {
            this.bias = bias;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link Linear}
         * @throws IllegalArgumentException presumably, via {@code Preconditions.checkArgument},
         *     when units was not set to a positive value
         */
        public Linear build() {
            Preconditions.checkArgument(units > 0, "You must specify unit");
            return new Linear(this);
        }
    }

    /**
     * Creates the block and registers its parameters.
     *
     * @param builder the configured builder
     */
    Linear(Builder builder) {
        super(VERSION);
        units = builder.units;
        // The weight shape is resolved lazily through a callback because inputFeatures is
        // not known until beforeInitialize (or loadMetadata) has been called.
        weight =
                addParameter(
                        new Parameter("weight", this, ParameterType.WEIGHT),
                        shapes -> new Shape(units, inputFeatures));
        if (builder.bias) {
            bias = addParameter(new Parameter("bias", this, ParameterType.BIAS), new Shape(units));
        }
    }

    /**
     * Runs the linear operation on a single input array.
     *
     * @param parameterStore the store used to fetch the parameter values for the input's device
     * @param inputs a list that must contain exactly one array
     * @param training forwarded to {@code ParameterStore.getValue}
     * @param params unused by this block
     * @return the result of {@link #linear(NDArray, NDArray, NDArray)}
     */
    @Override // ai.djl.nn.Block
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray weightValue = parameterStore.getValue(weight, device, training);
        // NOTE(review): bias is null when disabled via optBias(false); this relies on
        // ParameterStore.getValue accepting a null parameter - confirm.
        NDArray biasValue = parameterStore.getValue(bias, device, training);
        return linear(input, weightValue, biasValue);
    }

    /**
     * Returns the stored input shape (without its last dimension) with {@code units} appended.
     *
     * <p>The {@code inputs} argument is not consulted; the result is derived from the shape
     * recorded by {@link #beforeInitialize(Shape[])} or {@link #loadMetadata(byte,
     * DataInputStream)}. NOTE(review): dereferences {@code inputShape}, so this presumably throws
     * {@code NullPointerException} if called before either has set it - confirm.
     *
     * @param manager unused
     * @param inputs unused
     * @return a single-element array with the output shape
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        return new Shape[] {inputShape.addAll(new Shape(units))};
    }

    /**
     * Describes the single input of this block.
     *
     * @return a one-entry list mapping {@code "linearInput"} to the stored input shape
     */
    @Override // ai.djl.nn.AbstractBlock, ai.djl.nn.Block
    public PairList<String, Shape> describeInput() {
        return new PairList<>(
                Collections.singletonList("linearInput"), Collections.singletonList(inputShape));
    }

    /**
     * Records the input shapes and derives {@code inputFeatures} (last dimension of the first
     * input) and {@code inputShape} (all leading dimensions of the first input).
     *
     * @param inputShapes the shapes of the inputs; only the first one is inspected
     */
    @Override // ai.djl.nn.AbstractBlock
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        Shape input = inputShapes[0];
        int lastAxis = input.dimension() - 1;
        inputFeatures = input.get(lastAxis);
        inputShape = input.slice(0, lastAxis);
    }

    /**
     * Writes the metadata in the current ({@code VERSION}) layout: units, inputFeatures, then the
     * encoded input shape.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override // ai.djl.nn.AbstractBlock
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeLong(units);
        os.writeLong(inputFeatures);
        os.write(inputShape.getEncoded());
    }

    /**
     * Reads the metadata written by any supported encoding version (1 through {@code VERSION}).
     *
     * <p>Versions 1 and 2 do not store {@code units}, so the value supplied by the builder is
     * kept. Every version ends with the encoded input shape.
     *
     * @param version the encoding version of the stream
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if {@code version} is not supported
     * @throws IllegalArgumentException if a version 2 or 3 stream has its flatten flag set
     */
    @Override // ai.djl.nn.AbstractBlock
    public void loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        switch (version) {
            case VERSION:
                units = is.readLong();
                inputFeatures = is.readLong();
                break;
            case 3:
                units = is.readLong();
                if (is.readBoolean()) {
                    throw new IllegalArgumentException("flatten is not supported!");
                }
                inputFeatures = is.readLong();
                break;
            case 2:
                if (is.readBoolean()) {
                    throw new IllegalArgumentException("flatten is not supported!");
                }
                inputFeatures = is.readLong();
                break;
            case 1:
                // Version 1 stores an encoded shape; its size() is taken as the feature count.
                inputFeatures = Shape.decode(is).size();
                break;
            default:
                throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        inputShape = Shape.decode(is);
    }

    /**
     * Applies the linear operation without a bias.
     *
     * @param input the input array
     * @param weight the weight array
     * @return the result of {@link #linear(NDArray, NDArray, NDArray)} with a {@code null} bias
     */
    public static NDList linear(NDArray input, NDArray weight) {
        return linear(input, weight, null);
    }

    /**
     * Applies the linear operation by delegating to the input's internal engine implementation.
     *
     * @param input the input array
     * @param weight the weight array
     * @param bias the bias array; the bias-less overload passes {@code null} here
     * @return the list returned by the engine's {@code linear} implementation
     */
    public static NDList linear(NDArray input, NDArray weight, NDArray bias) {
        return input.getNDArrayInternal().linear(input, weight, bias);
    }

    /**
     * Creates a builder for a {@link Linear} block.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
