package ai.djl.training;

import ai.djl.Device;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Optimizer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:lib/api-0.9.0.jar:ai/djl/training/DefaultTrainingConfig.class */
/**
 * {@code DefaultTrainingConfig} is a mutable, fluent implementation of {@link TrainingConfig}.
 *
 * <p>Only the {@link Loss} is required (via the constructor). Every other setting has a default:
 *
 * <ul>
 *   <li>initializer: {@code XavierInitializer(RandomType.GAUSSIAN, FactorType.IN, 2.0f)}
 *   <li>optimizer: {@code Adam.builder().build()}
 *   <li>data manager: {@link DataManager#DEFAULT_DATA_MANAGER}
 *   <li>devices: unset, resolved lazily in {@link #getDevices()}
 *   <li>evaluators and training listeners: empty
 * </ul>
 *
 * <p>The {@code opt*} and {@code add*} methods mutate this instance and return {@code this} for
 * chaining. The class performs no synchronization.
 */
public class DefaultTrainingConfig implements TrainingConfig {

    /** Explicitly configured devices; {@code null} means "use the fallback in getDevices()". */
    private Device[] devices;

    private Loss loss;

    private Initializer initializer =
            new XavierInitializer(
                    XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2.0f);

    private Optimizer optimizer = Adam.builder().build();

    private DataManager dataManager = DataManager.DEFAULT_DATA_MANAGER;

    private final List<Evaluator> evaluators = new ArrayList<>();

    private final List<TrainingListener> listeners = new ArrayList<>();

    /**
     * Creates a training configuration that uses the given loss function.
     *
     * <p>The argument is stored as-is; it is not null-checked here.
     *
     * @param loss the loss function to train with
     */
    public DefaultTrainingConfig(Loss loss) {
        this.loss = loss;
    }

    /**
     * Replaces the parameter initializer.
     *
     * @param initializer the initializer to use instead of the default Xavier initializer
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optInitializer(Initializer initializer) {
        this.initializer = initializer;
        return this;
    }

    /**
     * Sets the devices to train on.
     *
     * <p>The array reference is stored directly (not copied). Passing {@code null} restores the
     * fallback behavior of {@link #getDevices()}.
     *
     * @param deviceArr the devices to use
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optDevices(Device[] deviceArr) {
        this.devices = deviceArr;
        return this;
    }

    /**
     * Replaces the optimizer.
     *
     * @param optimizer the optimizer to use instead of the default Adam optimizer
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    /**
     * Replaces the data manager.
     *
     * @param dataManager the data manager to use instead of {@link
     *     DataManager#DEFAULT_DATA_MANAGER}
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig optDataManager(DataManager dataManager) {
        this.dataManager = dataManager;
        return this;
    }

    /**
     * Appends an evaluator to the list returned by {@link #getEvaluators()}.
     *
     * @param evaluator the evaluator to add
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig addEvaluator(Evaluator evaluator) {
        this.evaluators.add(evaluator);
        return this;
    }

    /**
     * Appends the given listeners, in argument order, to the list returned by {@link
     * #getTrainingListeners()}.
     *
     * @param trainingListenerArr the listeners to add
     * @return this configuration, for chaining
     */
    public DefaultTrainingConfig addTrainingListeners(TrainingListener... trainingListenerArr) {
        this.listeners.addAll(Arrays.asList(trainingListenerArr));
        return this;
    }

    /**
     * Returns the configured devices.
     *
     * <p>If none were set, this returns {@code Device.getDevices(Integer.MAX_VALUE)}. Presumably
     * that means "all available devices", but confirm against {@code Device.getDevices}.
     *
     * @return the configured device array (the same reference passed to {@link #optDevices}), or
     *     the fallback described above
     */
    @Override // ai.djl.training.TrainingConfig
    public Device[] getDevices() {
        return this.devices == null ? Device.getDevices(Integer.MAX_VALUE) : this.devices;
    }

    /** {@inheritDoc} */
    @Override // ai.djl.training.TrainingConfig
    public Initializer getInitializer() {
        return this.initializer;
    }

    /** {@inheritDoc} */
    @Override // ai.djl.training.TrainingConfig
    public Optimizer getOptimizer() {
        return this.optimizer;
    }

    /**
     * Returns the loss function supplied to the constructor.
     *
     * @return the loss function
     */
    @Override // ai.djl.training.TrainingConfig
    public Loss getLossFunction() {
        return this.loss;
    }

    /** {@inheritDoc} */
    @Override // ai.djl.training.TrainingConfig
    public DataManager getDataManager() {
        return this.dataManager;
    }

    /**
     * Returns the evaluators added via {@link #addEvaluator}.
     *
     * <p>This is the live internal list, not a copy. Modifications made through it are visible to
     * this configuration.
     *
     * @return the evaluator list
     */
    @Override // ai.djl.training.TrainingConfig
    public List<Evaluator> getEvaluators() {
        return this.evaluators;
    }

    /**
     * Returns the listeners added via {@link #addTrainingListeners}.
     *
     * <p>This is the live internal list, not a copy. Modifications made through it are visible to
     * this configuration.
     *
     * @return the training listener list
     */
    @Override // ai.djl.training.TrainingConfig
    public List<TrainingListener> getTrainingListeners() {
        return this.listeners;
    }
}
