package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.pytorch.jni.IValueUtils;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:lib/pytorch-engine-0.9.0.jar:ai/djl/pytorch/engine/PtSymbolBlock.class */
/**
 * A {@link SymbolBlock} backed by a native PyTorch module handle.
 *
 * <p>Only forward passes are supported; initialization, casting, clearing, parameter access and
 * (de)serialization through DJL throw {@link UnsupportedOperationException}. Input and output
 * names/shapes are captured from the first successful forward call and exposed through {@link
 * #describeInput()} and {@link #describeOutput()}.
 */
public class PtSymbolBlock extends NativeResource<Long> implements SymbolBlock {

    private static final Logger logger = LoggerFactory.getLogger(PtSymbolBlock.class);

    private PtNDManager manager;

    // Mode last pushed to the native module; starts as true, i.e. the native module is assumed to
    // begin in training mode — NOTE(review): confirm against how the module is loaded.
    private boolean isTrain;

    // volatile: written once by the first forward call and read afterwards without locking, so the
    // fields must be safely published to other threads.
    private volatile PairList<String, Shape> inputDescriptions;
    private volatile PairList<String, Shape> outputDescriptions;

    // volatile: required for the double-checked test in forward() to be correct under the JMM.
    private volatile boolean first;

    /**
     * Wraps a native module handle and attaches this block to the given manager.
     *
     * @param ptNDManager the manager this block is attached to; it is detached again on close
     * @param j the native module handle
     */
    public PtSymbolBlock(PtNDManager ptNDManager, long j) {
        super(Long.valueOf(j));
        this.manager = ptNDManager;
        ptNDManager.attach(getUid(), this);
        this.isTrain = true;
        this.first = true;
    }

    /**
     * Deletes the native module and detaches this block from its manager. Only the first call has
     * an effect, because the handle is atomically swapped to {@code null}.
     */
    @Override
    public void close() {
        Long pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            JniUtils.deleteModule(pointer);
            this.manager.detach(getUid());
            this.manager = null;
        }
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void removeLastBlock() {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /**
     * Runs the native module on {@code inputs}.
     *
     * <p>If {@code training} differs from the mode last requested, the native module is switched to
     * training or inference mode first. The first call also records input/output names and shapes;
     * that call is serialized by a lock on {@code PtSymbolBlock.class}, which is shared by all
     * instances.
     *
     * @param parameterStore unused by this implementation
     * @param inputs the input arrays
     * @param training whether to run in training mode
     * @param params unused by this implementation
     * @return the module outputs
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (this.isTrain != training) {
            this.isTrain = training;
            if (training) {
                JniUtils.enableTrainingMode(this);
            } else {
                JniUtils.enableInferenceMode(this);
            }
        }
        if (this.first) {
            synchronized (PtSymbolBlock.class) {
                if (this.first) {
                    return forwardAndRecordDescriptions(inputs, training);
                }
            }
        }
        return IValueUtils.forward(this, inputs, training);
    }

    /**
     * Runs one forward pass and publishes the observed input/output names and shapes.
     *
     * <p>Descriptions are built in locals and assigned only after the native call succeeds, so a
     * failing first call leaves no partially populated state and is retried on the next call.
     */
    private NDList forwardAndRecordDescriptions(NDList inputs, boolean training) {
        PairList<String, Shape> inputDesc = new PairList<>();
        for (NDArray array : inputs) {
            inputDesc.add(array.getName(), array.getShape());
        }
        NDList outputs = IValueUtils.forward(this, inputs, training);
        PairList<String, Shape> outputDesc = new PairList<>();
        for (NDArray array : outputs) {
            outputDesc.add(array.getName(), array.getShape());
        }
        this.inputDescriptions = inputDesc;
        this.outputDescriptions = outputDesc;
        // Written last: a thread that observes first == false also sees the descriptions above.
        this.first = false;
        return outputs;
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void setInitializer(Initializer initializer) {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void setInitializer(Initializer initializer, String str) {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public Shape[] initialize(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /**
     * Always returns {@code true}; the native module is treated as already initialized.
     *
     * @return {@code true}
     */
    @Override
    public boolean isInitialized() {
        return true;
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void clear() {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /**
     * Returns the input names and shapes recorded by the first successful forward call.
     *
     * @return the recorded input descriptions, or {@code null} (after logging a warning) if no
     *     forward call has completed yet
     */
    @Override
    public PairList<String, Shape> describeInput() {
        PairList<String, Shape> desc = this.inputDescriptions;
        if (desc == null) {
            logger.warn(
                    "Input shapes are unknown, please run predict or forward once and call describeInput again.");
        }
        return desc;
    }

    /**
     * Returns the output names and shapes recorded by the first successful forward call.
     *
     * @return the recorded output descriptions, or {@code null} (after logging a warning) if no
     *     forward call has completed yet
     */
    @Override
    public PairList<String, Shape> describeOutput() {
        PairList<String, Shape> desc = this.outputDescriptions;
        if (desc == null) {
            logger.warn(
                    "Output shapes are unknown, please run predict or forward once and call describeOutput again.");
        }
        return desc;
    }

    /** {@inheritDoc} Not implemented for PyTorch modules. */
    @Override
    public BlockList getChildren() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} Not implemented for PyTorch modules. */
    @Override
    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} Not implemented for PyTorch modules. */
    @Override
    public ParameterList getParameters() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} Not implemented for PyTorch modules. */
    @Override
    public Shape getParameterShape(String str, Shape[] shapeArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns an empty array; output shapes are not computed statically for PyTorch modules. Use
     * {@link #describeOutput()} after a forward call instead.
     *
     * @return an empty array
     */
    @Override
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        return new Shape[0];
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void saveParameters(DataOutputStream dataOutputStream) {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }

    /** {@inheritDoc} Not supported for PyTorch modules. */
    @Override
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream) {
        throw new UnsupportedOperationException("Not supported for PyTorch");
    }
}
