public class SparkDl4jLayer
extends java.lang.Object
implements java.io.Serializable
Constructor and Description |
---|
SparkDl4jLayer(org.apache.spark.api.java.JavaSparkContext sc, NeuralNetConfiguration conf) |
SparkDl4jLayer(org.apache.spark.SparkContext sparkContext, NeuralNetConfiguration conf) |
Modifier and Type | Method | Description |
---|---|---|
Layer | fit(org.apache.spark.api.java.JavaSparkContext sc, org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> rdd) | Fit the layer on the given RDD of labeled points, using the given Spark context. |
Layer | fit(java.lang.String path, int labelIndex, org.datavec.api.records.reader.RecordReader recordReader) | Fit the layer on the text file at the given path, read with the given record reader. |
Layer | fitDataSet(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> rdd) | Fit the layer on a JavaRDD of DataSet objects. |
org.apache.spark.mllib.linalg.Matrix | predict(org.apache.spark.mllib.linalg.Matrix features) | Predict outputs for the given feature matrix. |
org.apache.spark.mllib.linalg.Vector | predict(org.apache.spark.mllib.linalg.Vector point) | Predict the output for the given feature vector. |
static Layer | train(org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> data, NeuralNetConfiguration conf) | Train a layer on the given data using the given configuration. |
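The fit(path, labelIndex, recordReader) overload uses labelIndex to say which field of each record holds the label. The JDK-only sketch below illustrates that idea for comma-delimited text. It assumes every other column is a numeric feature, and the class LabelIndexSketch and its split method are hypothetical names. In the real API the parsing is done by the supplied RecordReader, not by code like this.

```java
import java.util.Arrays;

// Hypothetical illustration only: splits one comma-delimited record into a label
// and a feature vector, the way labelIndex designates the label column.
// Assumption: all non-label columns are numeric features. The real parsing is
// done by the RecordReader passed to fit(path, labelIndex, recordReader).
final class LabelIndexSketch {

    /** One parsed record: the label value and the remaining feature columns. */
    static final class Record {
        final double label;
        final double[] features;

        Record(double label, double[] features) {
            this.label = label;
            this.features = features;
        }
    }

    /** Takes column labelIndex as the label and every other column as a feature. */
    static Record split(String line, int labelIndex) {
        String[] columns = line.split(",");
        if (labelIndex < 0 || labelIndex >= columns.length) {
            throw new IllegalArgumentException("labelIndex out of range: " + labelIndex);
        }
        double[] features = new double[columns.length - 1];
        int next = 0;
        for (int i = 0; i < columns.length; i++) {
            if (i != labelIndex) {
                features[next++] = Double.parseDouble(columns[i].trim());
            }
        }
        double label = Double.parseDouble(columns[labelIndex].trim());
        return new Record(label, features);
    }

    public static void main(String[] args) {
        // Label stored in the last column (index 4).
        Record r = split("5.1,3.5,1.4,0.2,0", 4);
        System.out.println(r.label + " " + Arrays.toString(r.features));
    }
}
```

Running main prints `0.0 [5.1, 3.5, 1.4, 0.2]`: the column at labelIndex is pulled out and the remaining columns keep their original order.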
public SparkDl4jLayer(org.apache.spark.SparkContext sparkContext, NeuralNetConfiguration conf)

public SparkDl4jLayer(org.apache.spark.api.java.JavaSparkContext sc, NeuralNetConfiguration conf)

public Layer fit(java.lang.String path, int labelIndex, org.datavec.api.records.reader.RecordReader recordReader)
path
- the path to the text file
labelIndex
- the index of the label column in each record
recordReader
- the record reader

public Layer fit(org.apache.spark.api.java.JavaSparkContext sc, org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> rdd)
sc
- the Spark context
rdd
- the RDD to fit

public Layer fitDataSet(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> rdd)
rdd
- the RDD to fit

public org.apache.spark.mllib.linalg.Matrix predict(org.apache.spark.mllib.linalg.Matrix features)
features
- the given feature matrix

public org.apache.spark.mllib.linalg.Vector predict(org.apache.spark.mllib.linalg.Vector point)
point
- the vector to predict

public static Layer train(org.apache.spark.api.java.JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> data, NeuralNetConfiguration conf)
data
- the data to train on
conf
- the layer configuration
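predict(Matrix) takes a Spark MLlib Matrix, and MLlib's dense matrices store their values in column-major order: Matrices.dense(numRows, numCols, values) reads the values array column by column. A batch held in Java as double[][] with one example per row (an assumption about how predict interprets rows, not something this page states) therefore has to be flattened column by column first. The hypothetical helper below does that flattening with the JDK only.

```java
import java.util.Arrays;

// Hypothetical helper: flattens a row-major batch into the column-major array
// expected by Spark MLlib's Matrices.dense(numRows, numCols, values).
// Assumption: each row of the input is one example for predict(Matrix).
final class ColumnMajorSketch {

    static double[] toColumnMajor(double[][] rows) {
        int numRows = rows.length;
        int numCols = numRows == 0 ? 0 : rows[0].length;
        double[] values = new double[numRows * numCols];
        for (int r = 0; r < numRows; r++) {
            if (rows[r].length != numCols) {
                throw new IllegalArgumentException("ragged row " + r);
            }
            for (int c = 0; c < numCols; c++) {
                // Element (r, c) lives at offset c * numRows + r in column-major order.
                values[c * numRows + r] = rows[r][c];
            }
        }
        return values;
    }

    public static void main(String[] args) {
        double[][] batch = { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0} };
        // Prints [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
        System.out.println(Arrays.toString(toColumnMajor(batch)));
    }
}
```

With Spark on the classpath, the flattened array could then be wrapped as Matrices.dense(2, 3, values) before calling predict; that step is left out here to keep the sketch self-contained.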