public class SparkEarlyStoppingGraphTrainer extends BaseSparkEarlyStoppingTrainer<ComputationGraph>
Constructor and Description |
---|
`SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train)` |
`SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train, EarlyStoppingListener<ComputationGraph> listener)` |
`SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train, int examplesPerFit, int totalExamples)` |
`SparkEarlyStoppingGraphTrainer(org.apache.spark.SparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train)` |
`SparkEarlyStoppingGraphTrainer(org.apache.spark.SparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train, int examplesPerFit, int totalExamples)` |
Modifier and Type | Method and Description |
---|---|
`protected void` | `fit(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> data)` |
`protected void` | `fitMulti(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> data)` |
`protected double` | `getScore()` |
Methods inherited from class `BaseSparkEarlyStoppingTrainer`: `fit`, `setListener`
public SparkEarlyStoppingGraphTrainer(org.apache.spark.SparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train, int examplesPerFit, int totalExamples)
public SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train, int examplesPerFit, int totalExamples)
public SparkEarlyStoppingGraphTrainer(org.apache.spark.SparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train)
public SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train)
public SparkEarlyStoppingGraphTrainer(org.apache.spark.api.java.JavaSparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> train, EarlyStoppingListener<ComputationGraph> listener)
protected void fit(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> data)
Specified by:
`fit` in class `BaseSparkEarlyStoppingTrainer<ComputationGraph>`
protected void fitMulti(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> data)
Specified by:
`fitMulti` in class `BaseSparkEarlyStoppingTrainer<ComputationGraph>`
protected double getScore()
Specified by:
`getScore` in class `BaseSparkEarlyStoppingTrainer<ComputationGraph>`