package org.kramerlab.autoencoder.neuralnet.rbm;

import org.kramerlab.autoencoder.math.matrix.Mat;
import org.kramerlab.autoencoder.math.optimization.ResultSelector;
import org.kramerlab.autoencoder.math.optimization.TerminationCriterion;
import org.kramerlab.autoencoder.neuralnet.FullBipartiteConnection;
import org.kramerlab.autoencoder.neuralnet.Layer;
import org.kramerlab.autoencoder.neuralnet.NeuralNet;
import org.kramerlab.autoencoder.visualization.TrainingObserver;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.immutable.C$colon$colon;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.Ordering;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;
import weka.gui.beans.xml.XMLBeans;

/* compiled from: Rbm.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005-h\u0001B\u0001\u0003\u00015\u00111A\u00152n\u0015\t\u0019A!A\u0002sE6T!!\u0002\u0004\u0002\u00139,WO]1m]\u0016$(BA\u0004\t\u0003-\tW\u000f^8f]\u000e|G-\u001a:\u000b\u0005%Q\u0011!C6sC6,'\u000f\\1c\u0015\u0005Y\u0011aA8sO\u000e\u00011c\u0001\u0001\u000f%A\u0011q\u0002E\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\n\u001d\u0016,(/\u00197OKR\u00042aD\n\u0016\u0013\t!BAA\u0007OKV\u0014\u0018\r\u001c(fi2K7.\u001a\t\u0003-\u0001i\u0011A\u0001\u0005\t1\u0001\u0011)\u0019!C\u00013\u00059a/[:jE2,W#\u0001\u000e\u0011\u0005YY\u0012B\u0001\u000f\u0003\u0005!\u0011&-\u001c'bs\u0016\u0014\b\u0002\u0003\u0010\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u000e\u0002\u0011YL7/\u001b2mK\u0002B\u0001\u0002\t\u0001\u0003\u0006\u0004%\t!I\u0001\u000bG>tg.Z2uS>tW#\u0001\u0012\u0011\u0005=\u0019\u0013B\u0001\u0013\u0005\u0005]1U\u000f\u001c7CSB\f'\u000f^5uK\u000e{gN\\3di&|g\u000e\u0003\u0005'\u0001\t\u0005\t\u0015!\u0003#\u0003-\u0019wN\u001c8fGRLwN\u001c\u0011\t\u0011!\u0002!Q1A\u0005\u0002e\ta\u0001[5eI\u0016t\u0007\u0002\u0003\u0016\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u000e\u0002\u000f!LG\rZ3oA!)A\u0006\u0001C\u0001[\u00051A(\u001b8jiz\"B!\u0006\u00180a!)\u0001d\u000ba\u00015!)\u0001e\u000ba\u0001E!)\u0001f\u000ba\u00015!)!\u0007\u0001C!g\u0005)!-^5mIR\u0011Q\u0003\u000e\u0005\u0006kE\u0002\rAN\u0001\u0003YN\u00042aN!E\u001d\tAdH\u0004\u0002:y5\t!H\u0003\u0002<\u0019\u00051AH]8pizJ\u0011!P\u0001\u0006g\u000e\fG.Y\u0005\u0003\u007f\u0001\u000bq\u0001]1dW\u0006<WMC\u0001>\u0013\t\u00115I\u0001\u0003MSN$(BA 
A!\tyQ)\u0003\u0002G\t\t)A*Y=fe\")\u0001\n\u0001C\u0001\u0013\u0006Y1m\u001c8gC\n,H.\u0019;f)\rQ%\u000b\u0016\t\u0003\u0017Bk\u0011\u0001\u0014\u0006\u0003\u001b:\u000ba!\\1ue&D(BA(\u0007\u0003\u0011i\u0017\r\u001e5\n\u0005Ec%aA'bi\")1k\u0012a\u0001\u0015\u0006\u0001\u0002.\u001b3eK:\f5\r^5wCRLwN\u001c\u0005\b+\u001e\u0003\n\u00111\u0001W\u0003\r\u001a\u0018-\u001c9mKZK7/\u001b2mKVs\u0017\u000e^:EKR,'/\\5oSN$\u0018nY1mYf\u0004\"a\u0016-\u000e\u0003\u0001K!!\u0017!\u0003\u000f\t{w\u000e\\3b]\")1\f\u0001C\u00019\u0006iq-\u001b2cgN\u000bW\u000e\u001d7j]\u001e$B!\u00181cOB!qK\u0018&K\u0013\ty\u0006I\u0001\u0004UkBdWM\r\u0005\u0006Cj\u0003\rAS\u0001\u000em&\u001c\u0018N\u00197f'R\fG/Z:\t\u000f\rT\u0006\u0013!a\u0001I\u0006)1\u000f^3qgB\u0011q+Z\u0005\u0003M\u0002\u00131!\u00138u\u0011\u001d)&\f%AA\u0002YCQ!\u001b\u0001\u0005\u0012)\f\u0011#\u001a=ue\u0006\u001cGo\u0015;bi&\u001cH/[2t)\rYgn\u001c\t\u0006/2T%JS\u0005\u0003[\u0002\u0013a\u0001V;qY\u0016\u001c\u0004\"B1i\u0001\u0004Q\u0005\"\u00029i\u0001\u0004Q\u0015!\u00055jI\u0012,g.Q2uSZ\fG/[8og\")!\u000f\u0001C\tg\u0006)2m\u001c8ue\u0006\u001cH/\u001b<f\t&4XM]4f]\u000e,G\u0003B6um^DQ!^9A\u0002)\u000b\u0011\"\\5oS\n\fGo\u00195\t\u000f\r\f\b\u0013!a\u0001I\"9Q+\u001dI\u0001\u0002\u00041\u0006\"B=\u0001\t\u0003Q\u0018!\u0002;sC&tWcA>\u0002\nQYA0a\u0007\u0002 
\u0005%\u00121HA&)\t)R\u0010C\u0004\u007fq\u0006\u0005\t9A@\u0002\u0015\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007E\u00038\u0003\u0003\t)!C\u0002\u0002\u0004\r\u0013\u0001b\u0014:eKJLgn\u001a\t\u0005\u0003\u000f\tI\u0001\u0004\u0001\u0005\u000f\u0005-\u0001P1\u0001\u0002\u000e\t9a)\u001b;oKN\u001c\u0018\u0003BA\b\u0003+\u00012aVA\t\u0013\r\t\u0019\u0002\u0011\u0002\b\u001d>$\b.\u001b8h!\r9\u0016qC\u0005\u0004\u00033\u0001%aA!os\"1\u0011Q\u0004=A\u0002)\u000b1\u0002\u001e:bS:LgnZ*fi\"9\u0011\u0011\u0005=A\u0002\u0005\r\u0012!D2p]\u001aLw-\u001e:bi&|g\u000eE\u0002\u0017\u0003KI1!a\n\u0003\u0005a\u0011&-\u001c+sC&t\u0017N\\4D_:4\u0017nZ;sCRLwN\u001c\u0005\n\u0003WA\b\u0013!a\u0001\u0003[\t\u0011\u0003\u001e:bS:LgnZ(cg\u0016\u0014h/\u001a:t!\u00119\u0014)a\f\u0011\t\u0005E\u0012qG\u0007\u0003\u0003gQ1!!\u000e\u0007\u000351\u0018n];bY&T\u0018\r^5p]&!\u0011\u0011HA\u001a\u0005A!&/Y5oS:<wJY:feZ,'\u000fC\u0004\u0002>a\u0004\r!a\u0010\u0002)Q,'/\\5oCRLwN\\\"sSR,'/[8o!\u0019\t\t%a\u0012\u0016I6\u0011\u00111\t\u0006\u0004\u0003\u000br\u0015\u0001D8qi&l\u0017N_1uS>t\u0017\u0002BA%\u0003\u0007\u0012A\u0003V3s[&t\u0017\r^5p]\u000e\u0013\u0018\u000e^3sS>t\u0007bBA'q\u0002\u0007\u0011qJ\u0001\u000fe\u0016\u001cX\u000f\u001c;TK2,7\r^8s!\u001d\t\t%!\u0015\u0016\u0003\u000bIA!a\u0015\u0002D\tq!+Z:vYR\u001cV\r\\3di>\u0014\bbBA,\u0001\u0011\u0005\u0013\u0011L\u0001\ti>\u001cFO]5oOR\u0011\u00111\f\t\u0005\u0003;\n9'\u0004\u0002\u0002`)!\u0011\u0011MA2\u0003\u0011a\u0017M\\4\u000b\u0005\u0005\u0015\u0014\u0001\u00026bm\u0006LA!!\u001b\u0002`\t11\u000b\u001e:j]\u001eDq!!\u001c\u0001\t\u0003\ny'A\u0003dY>tW\rF\u0001\u0016\u0011\u001d\t\u0019\b\u0001C\u0001\u0003k\nAB]3j]&$\u0018.\u00197ju\u0016$2!FA<\u0011!\tI(!\u001dA\u0002\u0005\r\u0012AB2p]\u001aLw\rC\u0005\u0002~\u0001\t\n\u0011\"\u0001\u0002��\u00059r-\u001b2cgN\u000bW\u000e\u001d7j]\u001e$C-\u001a4bk2$HEM\u000b\u0003\u0003\u0003S3\u0001ZABW\t\t)\t\u0005\u0003\u0002\b\u0006EUBAAE\u0015\u0011\tY)!$\u0002\u0013Ut7\r[3dW\u0016$'bAAH\u0001\
u0006Q\u0011M\u001c8pi\u0006$\u0018n\u001c8\n\t\u0005M\u0015\u0011\u0012\u0002\u0012k:\u001c\u0007.Z2lK\u00124\u0016M]5b]\u000e,\u0007\"CAL\u0001E\u0005I\u0011AAM\u0003]9\u0017N\u00192t'\u0006l\u0007\u000f\\5oO\u0012\"WMZ1vYR$3'\u0006\u0002\u0002\u001c*\u001aa+a!\t\u0013\u0005}\u0005!%A\u0005\u0002\u0005\u0005\u0016a\u0004;sC&tG\u0005Z3gCVdG\u000fJ\u001a\u0016\t\u0005\r\u0016qU\u000b\u0003\u0003KSC!!\f\u0002\u0004\u0012A\u00111BAO\u0005\u0004\ti\u0001C\u0005\u0002,\u0002\t\n\u0011\"\u0001\u0002\u001a\u0006)2m\u001c8gC\n,H.\u0019;fI\u0011,g-Y;mi\u0012\u0012\u0004\"CAX\u0001E\u0005I\u0011CA@\u0003}\u0019wN\u001c;sCN$\u0018N^3ESZ,'oZ3oG\u0016$C-\u001a4bk2$HE\r\u0005\n\u0003g\u0003\u0011\u0013!C\t\u00033\u000bqdY8oiJ\f7\u000f^5wK\u0012Kg/\u001a:hK:\u001cW\r\n3fM\u0006,H\u000e\u001e\u00134\u000f\u001d\t9L\u0001E\u0001\u0003s\u000b1A\u00152n!\r1\u00121\u0018\u0004\u0007\u0003\tA\t!!0\u0014\r\u0005m\u0016qXAc!\r9\u0016\u0011Y\u0005\u0004\u0003\u0007\u0004%AB!osJ+g\rE\u0002X\u0003\u000fL1!!3A\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011\u001da\u00131\u0018C\u0001\u0003\u001b$\"!!/\t\u0011\u0005E\u00171\u0018C\u0001\u0003'\facY;u\t\u0006$\u0018-\u00138u_6Kg.\u001b2bi\u000eDWm\u001d\u000b\u0007\u0003+\f9.a7\u0011\u0007]\n%\nC\u0004\u0002Z\u0006=\u0007\u0019\u0001&\u0002\t\u0011\fG/\u0019\u0005\b\u0003;\fy\r1\u0001e\u00035i\u0017N\\5cCR\u001c\u0007nU5{K\"Q\u0011\u0011]A^\u0003\u0003%I!a9\u0002\u0017I,\u0017\r\u001a*fg>dg/\u001a\u000b\u0003\u0003K\u0004B!!\u0018\u0002h&!\u0011\u0011^A0\u0005\u0019y%M[3di\u0002")
/* loaded from: input_file:lib/autoencoder-0.1.jar:org/kramerlab/autoencoder/neuralnet/rbm/Rbm.class */
public class Rbm extends NeuralNet {
    /** Visible layer; the first of the three layers handed to {@link NeuralNet}. */
    private final RbmLayer visible;
    /** Connection between visible and hidden layer; the second layer handed to {@link NeuralNet}. */
    private final FullBipartiteConnection connection;
    /** Hidden layer; the third layer handed to {@link NeuralNet}. */
    private final RbmLayer hidden;

    /**
     * Static forwarder to the companion object's {@code cutDataIntoMinibatches}.
     *
     * @param data the matrix to split (presumably one example per row -- TODO confirm in {@code Rbm$})
     * @param minibatchSize requested minibatch size
     * @return the minibatches produced by {@code Rbm$.MODULE$.cutDataIntoMinibatches}
     */
    public static List<Mat> cutDataIntoMinibatches(Mat data, int minibatchSize) {
        return Rbm$.MODULE$.cutDataIntoMinibatches(data, minibatchSize);
    }

    /** @return the visible layer */
    public RbmLayer visible() {
        return this.visible;
    }

    /** @return the connection between the visible and the hidden layer */
    public FullBipartiteConnection connection() {
        return this.connection;
    }

    /** @return the hidden layer */
    public RbmLayer hidden() {
        return this.hidden;
    }

    /**
     * Builds a new {@code Rbm} from exactly three layers: {@code RbmLayer :: FullBipartiteConnection
     * :: RbmLayer :: Nil}.
     *
     * <p>This is the only {@code build(List)} overload declared here; a second raw-typed
     * {@code build(List)} would have the same erasure and would not compile.
     *
     * @param list the layers, in visible / connection / hidden order
     * @return a new {@code Rbm} wrapping the three layers
     * @throws IllegalArgumentException if {@code list} does not have exactly that shape
     */
    @Override // org.kramerlab.autoencoder.neuralnet.NeuralNet, org.kramerlab.autoencoder.neuralnet.NeuralNetLike
    public NeuralNet build(List<Layer> list) {
        if (list instanceof C$colon$colon) {
            C$colon$colon first = (C$colon$colon) list;
            Object visibleLayer = first.hd$1();
            List afterVisible = first.tl$1();
            if (visibleLayer instanceof RbmLayer && afterVisible instanceof C$colon$colon) {
                C$colon$colon second = (C$colon$colon) afterVisible;
                Object connectionLayer = second.hd$1();
                List afterConnection = second.tl$1();
                if (connectionLayer instanceof FullBipartiteConnection && afterConnection instanceof C$colon$colon) {
                    C$colon$colon third = (C$colon$colon) afterConnection;
                    Object hiddenLayer = third.hd$1();
                    // Exactly three layers: nothing may follow the hidden layer.
                    if (hiddenLayer instanceof RbmLayer && Nil$.MODULE$.equals(third.tl$1())) {
                        return new Rbm((RbmLayer) visibleLayer, (FullBipartiteConnection) connectionLayer, (RbmLayer) hiddenLayer);
                    }
                }
            }
        }
        throw new IllegalArgumentException("Cannot build an Rbm from " + list);
    }

    /**
     * Maps a hidden-layer activation back to the visible layer.
     *
     * <p>The activation is passed through {@code hidden().sample} and then propagated back with
     * {@code reverse}. Unless {@code sampleVisibleUnitsDeterministically} is set, the result is
     * additionally passed through {@code visible().sample}.
     *
     * @param hiddenActivation activation of the hidden layer
     * @param sampleVisibleUnitsDeterministically if {@code true}, return the reversed value without
     *     calling {@code visible().sample}
     * @return the reconstructed visible values
     */
    public Mat confabulate(Mat hiddenActivation, boolean sampleVisibleUnitsDeterministically) {
        Mat reconstruction = reverse(hidden().sample(hiddenActivation));
        if (sampleVisibleUnitsDeterministically) {
            return reconstruction;
        }
        return visible().sample(reconstruction);
    }

    /** Default for the second parameter of {@link #confabulate(Mat, boolean)}: {@code false}. */
    public boolean confabulate$default$2() {
        return false;
    }

    /**
     * Runs {@code steps} rounds of forward pass followed by {@link #confabulate(Mat, boolean)},
     * starting from {@code visibleStates}.
     *
     * @param visibleStates starting visible values
     * @param steps number of confabulation rounds; {@code 0} performs just one forward pass
     * @param sampleVisibleUnitsDeterministically forwarded to {@link #confabulate(Mat, boolean)}
     * @return the final visible values paired with the forward-pass output computed from them
     * @throws IllegalArgumentException if {@code steps} is negative (previously this recursed until
     *     stack overflow because the {@code steps == 0} base case was never reached)
     */
    public Tuple2<Mat, Mat> gibbsSampling(Mat visibleStates, int steps, boolean sampleVisibleUnitsDeterministically) {
        if (steps < 0) {
            throw new IllegalArgumentException("gibbsSampling requires steps >= 0, got " + steps);
        }
        Mat hiddenActivation = mo2126apply(visibleStates);
        if (steps == 0) {
            return new Tuple2<>(visibleStates, hiddenActivation);
        }
        return gibbsSampling(confabulate(hiddenActivation, sampleVisibleUnitsDeterministically), steps - 1, sampleVisibleUnitsDeterministically);
    }

    /** Default for the second parameter of {@link #gibbsSampling(Mat, int, boolean)}: {@code 1}. */
    public int gibbsSampling$default$2() {
        return 1;
    }

    /** Default for the third parameter of {@link #gibbsSampling(Mat, int, boolean)}: {@code false}. */
    public boolean gibbsSampling$default$3() {
        return false;
    }

    /**
     * Computes row-averaged statistics of a visible / hidden pair.
     *
     * <p>With {@code n = visibleStates.height()}, returns
     * ({@code visibleStates.sumRows() / n}, {@code hiddenActivations.sumRows() / n},
     * {@code visibleStates^T * hiddenActivations / n}). These are presumably per-unit means and the
     * visible/hidden correlation matrix averaged over the minibatch.
     *
     * @param visibleStates visible values, one row per example (presumably)
     * @param hiddenActivations hidden values matching {@code visibleStates} row by row (presumably)
     * @return (visible statistic, hidden statistic, visible-hidden statistic)
     */
    public Tuple3<Mat, Mat, Mat> extractStatistics(Mat visibleStates, Mat hiddenActivations) {
        int rows = visibleStates.height();
        Mat visibleStatistic = visibleStates.sumRows().$div2(rows);
        Mat hiddenStatistic = hiddenActivations.sumRows().$div2(rows);
        Mat correlation = visibleStates.transpose().$times(hiddenActivations).$div2(rows);
        return new Tuple3<>(visibleStatistic, hiddenStatistic, correlation);
    }

    /**
     * Computes CD-{@code steps} differences for one minibatch.
     *
     * <p>The positive statistics come from the minibatch and its forward pass. The negative
     * statistics come from a chain that confabulates once and then runs {@code steps - 1} further
     * Gibbs rounds. Each returned matrix is the positive statistic minus the negative one.
     *
     * @param minibatch visible data of the minibatch
     * @param steps total number of confabulation rounds in the negative chain; must be at least 1
     * @param sampleVisibleUnitsDeterministically forwarded to {@link #confabulate(Mat, boolean)}
     * @return (visible difference, connection difference, hidden difference), in the same order as
     *     the layers of this net
     * @throws IllegalArgumentException if {@code steps < 1} (previously this recursed until stack
     *     overflow)
     */
    public Tuple3<Mat, Mat, Mat> contrastiveDivergence(Mat minibatch, int steps, boolean sampleVisibleUnitsDeterministically) {
        if (steps < 1) {
            throw new IllegalArgumentException("contrastiveDivergence requires steps >= 1, got " + steps);
        }
        // Positive phase: the data itself and its forward pass.
        Tuple2<Mat, Mat> positive = gibbsSampling(minibatch, 0, sampleVisibleUnitsDeterministically);
        Mat positiveVisible = (Mat) positive.mo2366_1();
        Mat positiveHidden = (Mat) positive.mo2365_2();

        // Negative phase: one confabulation plus (steps - 1) further Gibbs rounds.
        Tuple2<Mat, Mat> negative = gibbsSampling(confabulate(positiveHidden, sampleVisibleUnitsDeterministically), steps - 1, sampleVisibleUnitsDeterministically);
        Mat negativeVisible = (Mat) negative.mo2366_1();
        Mat negativeHidden = (Mat) negative.mo2365_2();

        Tuple3<Mat, Mat, Mat> positiveStats = extractStatistics(positiveVisible, positiveHidden);
        Tuple3<Mat, Mat, Mat> negativeStats = extractStatistics(negativeVisible, negativeHidden);

        Mat visibleDelta = ((Mat) positiveStats._1()).$minus((Mat) negativeStats._1());
        Mat connectionDelta = ((Mat) positiveStats._3()).$minus((Mat) negativeStats._3());
        Mat hiddenDelta = ((Mat) positiveStats._2()).$minus((Mat) negativeStats._2());
        return new Tuple3<>(visibleDelta, connectionDelta, hiddenDelta);
    }

    /** Default for the second parameter of {@link #contrastiveDivergence(Mat, int, boolean)}: {@code 1}. */
    public int contrastiveDivergence$default$2() {
        return 1;
    }

    /** Default for the third parameter of {@link #contrastiveDivergence(Mat, int, boolean)}: {@code false}. */
    public boolean contrastiveDivergence$default$3() {
        return false;
    }

    /**
     * Trains this RBM in epochs until {@code terminationCriterion} answers {@code true}.
     *
     * <p>Each epoch does the following:
     * <ul>
     *   <li>looks up the momentum for the epoch;</li>
     *   <li>shuffles the rows of {@code trainingSet};</li>
     *   <li>cuts it into minibatches of {@code configuration.minibatchSize()} and runs the
     *       per-minibatch closure {@code Rbm$$anonfun$train$2} on each;</li>
     *   <li>offers a candidate to {@code resultSelector}.</li>
     * </ul>
     * One candidate is also offered before the first epoch.
     *
     * @param trainingSet training data; its rows are shuffled via {@code shuffleRows()}, whose result
     *     is ignored, so this presumably mutates the caller's matrix in place
     * @param configuration supplies momentum per epoch and the minibatch size
     * @param trainingObservers observers handed to the per-minibatch closure
     * @param terminationCriterion queried with this RBM and the current epoch index before each epoch
     * @param resultSelector collects candidates and provides the returned result
     * @param ordering ordering on fitness values (Scala implicit evidence; not read directly here)
     * @param <Fitness> fitness type used by {@code resultSelector}
     * @return {@code resultSelector.result()}
     */
    public <Fitness> Rbm train(Mat trainingSet, RbmTrainingConfiguration configuration, List<TrainingObserver> trainingObservers, TerminationCriterion<Rbm, Object> terminationCriterion, ResultSelector<Rbm, Fitness> resultSelector, Ordering<Fitness> ordering) {
        resultSelector.consider(new Rbm$$anonfun$train$1(this));
        // Zero matrices shaped like each layer's parameters, shared with the per-minibatch closure.
        // They are presumably the momentum (velocity) accumulators -- TODO confirm in Rbm$$anonfun$train$2.
        ObjectRef visibleVelocity = new ObjectRef(visible().parameters().zero2());
        ObjectRef hiddenVelocity = new ObjectRef(hidden().parameters().zero2());
        ObjectRef connectionVelocity = new ObjectRef(connection().parameters().zero2());
        // Boxed so the closure sees the current epoch index.
        IntRef epoch = new IntRef(0);
        while (!BoxesRunTime.unboxToBoolean(terminationCriterion.mo2140apply(this, BoxesRunTime.boxToInteger(epoch.elem)))) {
            double momentum = configuration.momentum(epoch.elem);
            trainingSet.shuffleRows();
            List<Mat> minibatches = Rbm$.MODULE$.cutDataIntoMinibatches(trainingSet, configuration.minibatchSize());
            minibatches.foreach(new Rbm$$anonfun$train$2(this, trainingSet, configuration, trainingObservers, visibleVelocity, hiddenVelocity, connectionVelocity, epoch, momentum, new IntRef(0)));
            resultSelector.consider(new Rbm$$anonfun$train$3(this));
            epoch.elem++;
        }
        return resultSelector.result();
    }

    /** Default for the {@code trainingObservers} parameter of {@code train}: the empty list. */
    public <Fitness> List<TrainingObserver> train$default$3() {
        return Nil$.MODULE$;
    }

    /**
     * Describes the parameter dimensions and the layer classes of this RBM.
     *
     * @return e.g. {@code RBM[visible: h x w (Cls) connections: h x w hidden:h x w (Cls)]}, where the
     *     separator is {@code XMLBeans.VAL_X}
     */
    @Override // org.kramerlab.autoencoder.neuralnet.NeuralNet, scala.Function1
    public String toString() {
        return "RBM[visible: " + visible().parameters().height() + XMLBeans.VAL_X + visible().parameters().width()
            + " (" + visible().getClass().getSimpleName() + ")"
            + " connections: " + connection().parameters().height() + XMLBeans.VAL_X + connection().parameters().width()
            + " hidden:" + hidden().parameters().height() + XMLBeans.VAL_X + hidden().parameters().width()
            + " (" + hidden().getClass().getSimpleName() + ")]";
    }

    /**
     * Creates a new {@code Rbm} from {@code copy()} of each of the three layers.
     *
     * @return the copy
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Rbm m2160clone() {
        return new Rbm(visible().copy(), connection().copy(), hidden().copy());
    }

    /**
     * Creates a new {@code Rbm} whose layers are reinitialized.
     *
     * <p>Both RBM layers use {@code configuration.initialBiasScaling()}; the connection uses
     * {@code configuration.initialWeightScaling()}.
     *
     * @param configuration supplies the scaling factors
     * @return a new, reinitialized RBM
     */
    public Rbm reinitialize(RbmTrainingConfiguration configuration) {
        RbmLayer freshVisible = visible().reinitialize(configuration.initialBiasScaling());
        FullBipartiteConnection freshConnection = connection().reinitialize(configuration.initialWeightScaling());
        RbmLayer freshHidden = hidden().reinitialize(configuration.initialBiasScaling());
        return new Rbm(freshVisible, freshConnection, freshHidden);
    }

    /**
     * Forward pass; delegates to the inherited {@link NeuralNet} implementation.
     *
     * <p>This must call {@code super}. Calling {@code mo2126apply} unqualified here would invoke this
     * very method again and recurse until {@link StackOverflowError}.
     */
    @Override // org.kramerlab.autoencoder.neuralnet.NeuralNet, scala.Function1
    /* renamed from: apply */
    public Mat mo2126apply(Mat mat) {
        return super.mo2126apply(mat);
    }

    /** Raw-typed bridge that forwards to {@link #build(List)}. */
    @Override // org.kramerlab.autoencoder.neuralnet.NeuralNet, org.kramerlab.autoencoder.neuralnet.NeuralNetLike
    /* renamed from: build, reason: avoid collision after fix types in other method */
    @SuppressWarnings("unchecked")
    public NeuralNet build2(List list) {
        return build((List<Layer>) list);
    }

    /**
     * Creates an RBM from its three layers. The layers are also passed to {@link NeuralNet} as the
     * list {@code [rbmLayer, fullBipartiteConnection, rbmLayer2]}.
     *
     * <p>NOTE(review): in the Scala original the fields are presumably assigned before the super
     * constructor runs. Here they are assigned afterwards, so if {@link NeuralNet}'s constructor calls
     * an overridden method that reads {@code visible}/{@code connection}/{@code hidden}, it sees
     * {@code null} -- confirm against NeuralNet.
     *
     * @param rbmLayer visible layer
     * @param fullBipartiteConnection connection between visible and hidden layer
     * @param rbmLayer2 hidden layer
     */
    public Rbm(RbmLayer rbmLayer, FullBipartiteConnection fullBipartiteConnection, RbmLayer rbmLayer2) {
        super(List$.MODULE$.apply((Seq) Predef$.MODULE$.wrapRefArray(new Layer[]{rbmLayer, fullBipartiteConnection, rbmLayer2})));
        this.visible = rbmLayer;
        this.connection = fullBipartiteConnection;
        this.hidden = rbmLayer2;
    }
}
