package ai.djl.training.listener;

import ai.djl.TrainingDivergedException;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;

/* loaded from: input_file:lib/api-0.9.0.jar:ai/djl/training/listener/DivergenceCheckTrainingListener.class */
/**
 * A training listener that stops training as soon as the accumulated training loss turns into NaN.
 *
 * <p>On each training batch it reads the trainer's loss accumulator stored under {@link
 * EvaluatorTrainingListener#TRAIN_ALL}. If that value is NaN, it throws a {@link
 * TrainingDivergedException}, so a diverged run fails fast instead of continuing silently.
 */
public class DivergenceCheckTrainingListener extends TrainingListenerAdapter {

    /**
     * Inspects the trainer's accumulated training loss and aborts if it has become NaN.
     *
     * @param trainer the trainer whose loss accumulator (key {@code TRAIN_ALL}) is checked
     * @param batchData the batch data for this callback; not read by this check
     * @throws TrainingDivergedException if the accumulated training loss is NaN
     */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        float accumulatedLoss =
                trainer.getLoss().getAccumulator(EvaluatorTrainingListener.TRAIN_ALL);
        boolean lossIsNaN = Float.isNaN(accumulatedLoss);
        if (!lossIsNaN) {
            // Loss is still a number; nothing to do for this batch.
            return;
        }
        throw new TrainingDivergedException(
                "The Loss became NaN, try reduce learning rate,"
                        + "add clipGradient option to your optimizer, check input data and loss calculation.");
    }
}
