public class SparkDl4jLayer
extends java.lang.Object
implements java.io.Serializable
| Constructor and Description |
|---|
| SparkDl4jLayer(org.apache.spark.api.java.JavaSparkContext sc, NeuralNetConfiguration conf) |
| SparkDl4jLayer(org.apache.spark.SparkContext sparkContext, NeuralNetConfiguration conf) |
| Modifier and Type | Method | Description |
|---|---|---|
| Layer | fit(org.apache.spark.api.java.JavaSparkContext sc, org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> rdd) | Fit the given RDD using the given Spark context. |
| Layer | fit(java.lang.String path, int labelIndex, org.datavec.api.records.reader.RecordReader recordReader) | Fit the layer on the text file at the specified path, loaded through the Spark context. |
| Layer | fitDataSet(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> rdd) | Fit a Java RDD of DataSet objects. |
| org.apache.spark.mllib.linalg.Matrix | predict(org.apache.spark.mllib.linalg.Matrix features) | Predict the given feature matrix. |
| org.apache.spark.mllib.linalg.Vector | predict(org.apache.spark.mllib.linalg.Vector point) | Predict the given vector. |
| static Layer | train(org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> data, NeuralNetConfiguration conf) | Train a multi layer network. |
public SparkDl4jLayer(org.apache.spark.SparkContext sparkContext,
NeuralNetConfiguration conf)
public SparkDl4jLayer(org.apache.spark.api.java.JavaSparkContext sc,
NeuralNetConfiguration conf)
public Layer fit(java.lang.String path, int labelIndex, org.datavec.api.records.reader.RecordReader recordReader)
path - the path to the text file
labelIndex - the index of the label
recordReader - the record reader

public Layer fit(org.apache.spark.api.java.JavaSparkContext sc, org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> rdd)
sc - the Spark context
rdd - the rdd to fit

public Layer fitDataSet(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> rdd)
rdd - the rdd to fit

public org.apache.spark.mllib.linalg.Matrix predict(org.apache.spark.mllib.linalg.Matrix features)
features - the given feature matrix

public org.apache.spark.mllib.linalg.Vector predict(org.apache.spark.mllib.linalg.Vector point)
point - the vector to predict

public static Layer train(org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> data, NeuralNetConfiguration conf)
data - the data to train on
conf - the configuration of the network