public interface TrainingMaster<R extends TrainingResult,W extends TrainingWorker<R>>

This interface allows SparkDl4jMultiLayer and SparkComputationGraph to be used with different training methods.

Modifier and Type | Method and Description |
---|---|
void | addHook(TrainingHook trainingHook): Add a hook for the master for pre and post training. |
boolean | deleteTempFiles(org.apache.spark.api.java.JavaSparkContext sc): Attempt to delete any temporary files generated by this TrainingMaster. |
boolean | deleteTempFiles(org.apache.spark.SparkContext sc): Attempt to delete any temporary files generated by this TrainingMaster. |
void | executeTraining(SparkComputationGraph network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData): Deprecated due to poor performance. |
void | executeTraining(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData): Train the SparkComputationGraph with the specified data set. |
void | executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData): Deprecated due to poor performance. |
void | executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData): Train the SparkDl4jMultiLayer with the specified data set. |
void | executeTrainingMDS(SparkComputationGraph network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData): Deprecated due to poor performance. |
void | executeTrainingMDS(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> trainingData): Train the SparkComputationGraph with the specified data set. |
void | executeTrainingPaths(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths): EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized DataSet objects. |
void | executeTrainingPaths(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths): EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized DataSet objects. |
void | executeTrainingPathsMDS(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingMultiDataSetPaths): EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized MultiDataSet objects. |
boolean | getIsCollectTrainingStats(): Get the current setting for collectTrainingStats. |
SparkTrainingStats | getTrainingStats(): Return the training statistics. |
W | getWorkerInstance(SparkComputationGraph graph): Get the worker instance for this training master. |
W | getWorkerInstance(SparkDl4jMultiLayer network): Get the worker instance for this training master. |
void | removeHook(TrainingHook trainingHook): Remove a training hook from the worker. |
void | setCollectTrainingStats(boolean collectTrainingStats): Set whether the training statistics should be collected. |
void | setListeners(java.util.Collection<IterationListener> listeners): Set the iteration listeners. |
void | setListeners(StatsStorageRouter router, java.util.Collection<IterationListener> listeners): Set the iteration listeners and the StatsStorageRouter. |
java.lang.String | toJson(): Get the TrainingMaster configuration as JSON. |
java.lang.String | toYaml(): Get the TrainingMaster configuration as YAML. |
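
The path-based methods in the table (executeTrainingPaths, executeTrainingPathsMDS) take paths to serialized objects rather than the objects themselves. The following is only a rough, self-contained sketch of that idea, not DL4J code: it assumes plain Java serialization of a double[] as a stand-in for a DataSet saved with DataSet.save(OutputStream), and the class name PathsSketch and the file names are hypothetical.

```java
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration: only the list of paths is handed to the "training" loop.
final class PathsSketch {

    /** Writes three stand-in minibatches to disk, then consumes them by path only. */
    static int run() {
        try {
            Path dir = Files.createTempDirectory("paths-sketch");
            List<String> paths = new ArrayList<>();

            // Step 1: serialize each minibatch to its own file and remember the path.
            for (int i = 0; i < 3; i++) {
                Path file = dir.resolve("minibatch_" + i + ".bin");
                try (ObjectOutputStream out = new ObjectOutputStream(Files.newOutputStream(file))) {
                    out.writeObject(new double[] {i, i + 1.0}); // stand-in for a DataSet
                }
                paths.add(file.toString());
            }

            // Step 2: the consumer sees only the paths and loads each object on demand.
            int valuesSeen = 0;
            for (String p : paths) {
                try (ObjectInputStream in = new ObjectInputStream(Files.newInputStream(Paths.get(p)))) {
                    double[] minibatch = (double[]) in.readObject();
                    valuesSeen += minibatch.length; // placeholder for fitting on the minibatch
                }
                Files.delete(Paths.get(p)); // clean up the temporary file
            }
            Files.delete(dir);
            return valuesSeen;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```

In a real Spark job the list of paths would presumably be a JavaRDD<java.lang.String>, so that loading happens on the executors; the local loop here only mimics that flow.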
void removeHook(TrainingHook trainingHook)
trainingHook - the training hook to remove

void addHook(TrainingHook trainingHook)
trainingHook - the training hook to add

java.lang.String toJson()

java.lang.String toYaml()
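
As a rough illustration of the addHook / removeHook pair, the sketch below shows the register and unregister pattern in isolation. It does not use the DL4J types: SketchHook and SketchMaster are hypothetical, simplified stand-ins, and the preUpdate / postUpdate method names on the stand-in are assumptions chosen to mirror the "pre and post training" wording. Where and when a real TrainingMaster invokes its hooks depends on the implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class; not part of DL4J.
final class HookSketch {

    /** Registers a logging hook, runs one update with it and one after removing it. */
    static List<String> run() {
        List<String> log = new ArrayList<>();
        SketchHook hook = new SketchHook() {
            @Override public void preUpdate(String id) { log.add("pre:" + id); }
            @Override public void postUpdate(String id) { log.add("post:" + id); }
        };

        SketchMaster master = new SketchMaster();
        master.addHook(hook);
        master.update("batch-0", log);   // hook fires before and after this update
        master.removeHook(hook);
        master.update("batch-1", log);   // no hook registered any more
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}

/** Simplified stand-in for a training hook (assumed shape, not the DL4J interface). */
interface SketchHook {
    void preUpdate(String minibatchId);
    void postUpdate(String minibatchId);
}

/** Stand-in for the hook-related part of a training master. */
final class SketchMaster {
    private final List<SketchHook> hooks = new ArrayList<>();

    void addHook(SketchHook hook) { hooks.add(hook); }

    void removeHook(SketchHook hook) { hooks.remove(hook); }

    /** Pretends to perform one update, calling every registered hook around it. */
    void update(String minibatchId, List<String> log) {
        for (SketchHook h : hooks) h.preUpdate(minibatchId);
        log.add("update:" + minibatchId);
        for (SketchHook h : hooks) h.postUpdate(minibatchId);
    }
}
```

Running the sketch logs the pre and post entries only for the first update, since the hook is removed before the second one.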
W getWorkerInstance(SparkDl4jMultiLayer network)
network - Current SparkDl4jMultiLayer

W getWorkerInstance(SparkComputationGraph graph)
graph - Current SparkComputationGraph

void executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData)
network - Current network state
trainingData - Data to train on

@Deprecated void executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData)
network - Current network state
trainingData - Data to train on

void executeTrainingPaths(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths)
network - Current network state
trainingDataPaths - Paths of the serialized DataSet objects to train on

void executeTraining(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData)
graph - Current network state
trainingData - Data to train on

@Deprecated void executeTraining(SparkComputationGraph network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData)
DataSet.save(OutputStream)
network - Current network state
trainingData - Data to train on

void executeTrainingPaths(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths)
network - Current network state
trainingDataPaths - Paths of the serialized DataSet objects to train on

void executeTrainingPathsMDS(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingMultiDataSetPaths)
network - Current network state
trainingMultiDataSetPaths - Paths of the serialized MultiDataSet objects to train on

void executeTrainingMDS(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> trainingData)
graph - Current network state
trainingData - Data to train on

@Deprecated void executeTrainingMDS(SparkComputationGraph network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData)
network - Current network state
trainingData - Data to train on

void setCollectTrainingStats(boolean collectTrainingStats)
These statistics are primarily used for debugging and optimization, to show which aspects of network training take the most time.
collectTrainingStats - If true, training statistics will be collected

boolean getIsCollectTrainingStats()
SparkTrainingStats getTrainingStats()
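
The three statistics methods fit together as a toggle, a getter for the toggle, and a getter for the results. The sketch below illustrates that flow under stated assumptions: SketchStatsMaster is a hypothetical stand-in, the statistics are a plain map rather than a SparkTrainingStats object, and returning null before any statistics have been collected is a choice of this sketch rather than documented behavior.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical demo class; not part of DL4J.
final class StatsSketch {

    /** Trains once without stats, once with stats enabled, and reports what was collected. */
    static String run() {
        SketchStatsMaster master = new SketchStatsMaster();

        master.executeTraining(3);                        // collection is off by default here
        boolean noneBefore = master.getTrainingStats() == null;

        master.setCollectTrainingStats(true);             // enable collection
        master.executeTraining(3);
        Map<String, Long> stats = master.getTrainingStats();

        return noneBefore + ":" + stats.keySet();
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}

/** Stand-in for the statistics-related part of a training master. */
final class SketchStatsMaster {
    private boolean collectTrainingStats = false;
    private Map<String, Long> stats;                      // stays null until something is collected

    void setCollectTrainingStats(boolean collectTrainingStats) {
        this.collectTrainingStats = collectTrainingStats;
    }

    boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    Map<String, Long> getTrainingStats() {
        return stats;
    }

    /** Pretends to train on numBatches minibatches and records timings if enabled. */
    void executeTraining(int numBatches) {
        long start = System.nanoTime();
        long acc = 0;
        for (int i = 0; i < numBatches; i++) {
            acc += i;                                     // placeholder for real work
        }
        long elapsed = System.nanoTime() - start;

        if (collectTrainingStats) {
            if (stats == null) {
                stats = new LinkedHashMap<>();
            }
            stats.merge("fitTimeNanos", elapsed, Long::sum);
            stats.merge("batches", (long) numBatches, Long::sum);
        }
    }
}
```

Keeping collection off by default matches the description of the statistics as a debugging and optimization aid: the timing bookkeeping is only paid for when it is asked for.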
void setListeners(java.util.Collection<IterationListener> listeners)
listeners - Listeners to set

void setListeners(StatsStorageRouter router, java.util.Collection<IterationListener> listeners)
router - StatsStorageRouter in which to place the results
listeners - Listeners

boolean deleteTempFiles(org.apache.spark.api.java.JavaSparkContext sc)
sc - JavaSparkContext (used to access HDFS etc file systems, when required)

boolean deleteTempFiles(org.apache.spark.SparkContext sc)
sc - SparkContext (used to access HDFS etc file systems, when required)