public class ParameterAveragingTrainingMaster extends java.lang.Object implements TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>

TrainingMaster implementation for training networks on Spark. This is standard parameter averaging with a configurable averaging period.

Modifier and Type | Class and Description |
---|---|
static class | ParameterAveragingTrainingMaster.Builder |
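To make "standard parameter averaging with a configurable averaging period" concrete, here is a single-machine sketch in plain Java, with no Spark or DL4J involved. It assumes that in each round every worker starts from the shared parameters and fits averagingFrequency minibatches on its own data. Here each minibatch is modelled as one gradient step on a toy one-parameter quadratic loss per worker. The master then replaces the shared parameters with the element-wise mean of the workers' results. The class and method names are hypothetical, and this is not the library's implementation. The constructor documentation states that the updater state is averaged too when saveUpdater is true; the same average helper would apply to those arrays.

```java
import java.util.Arrays;

/** Hypothetical single-machine simulation of synchronous parameter averaging. */
class ParameterAveragingSketch {

    /** Element-wise mean of the parameter vectors returned by the workers. */
    static double[] average(double[][] workerParams) {
        int n = workerParams[0].length;
        double[] mean = new double[n];
        for (double[] params : workerParams) {
            for (int i = 0; i < n; i++) {
                mean[i] += params[i];
            }
        }
        for (int i = 0; i < n; i++) {
            mean[i] /= workerParams.length;
        }
        return mean;
    }

    /**
     * Each round: broadcast the shared parameter, let every worker take
     * averagingFrequency gradient steps on its own loss (w - target)^2,
     * then replace the shared parameter with the workers' average.
     */
    static double train(double[] workerTargets, int averagingFrequency, int rounds, double learningRate) {
        double shared = 0.0;
        for (int round = 0; round < rounds; round++) {
            double[][] results = new double[workerTargets.length][];
            for (int k = 0; k < workerTargets.length; k++) {
                double w = shared; // every worker starts the round from the same parameters
                for (int step = 0; step < averagingFrequency; step++) {
                    w -= learningRate * 2.0 * (w - workerTargets[k]); // gradient of (w - t)^2
                }
                results[k] = new double[] {w};
            }
            shared = average(results)[0]; // the averaging step
        }
        return shared;
    }

    public static void main(String[] args) {
        double[][] fromWorkers = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
        System.out.println(Arrays.toString(average(fromWorkers))); // [3.0, 4.0]
        // Three workers with different local optima; the averaged value settles near their mean, 3.0.
        System.out.println(train(new double[] {1.0, 3.0, 5.0}, 5, 50, 0.1));
    }
}
```

Because each toy loss is quadratic, every round shrinks the gap to the mean of the workers' optima by a fixed factor. Real networks offer no such guarantee. A larger averaging period generally means less communication but more divergence between workers before each average.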

Constructor and Description |
---|
ParameterAveragingTrainingMaster(boolean saveUpdater, java.lang.Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches) |
ParameterAveragingTrainingMaster(boolean saveUpdater, java.lang.Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches, Repartition repartition, RepartitionStrategy repartitionStrategy, boolean collectTrainingStats) |
ParameterAveragingTrainingMaster(boolean saveUpdater, java.lang.Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches, Repartition repartition, RepartitionStrategy repartitionStrategy, org.apache.spark.storage.StorageLevel storageLevel, boolean collectTrainingStats) |
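The prefetchNumBatches argument accepted by these constructors sets how many batches are fetched asynchronously ahead of training. A value of 0 disables prefetching. The general pattern is sketched below with java.util.concurrent. A background thread fills a bounded queue holding at most prefetchNumBatches batches while the consumer trains on earlier ones. The Prefetcher class and its end-of-data marker are hypothetical and say nothing about how DL4J implements prefetching internally.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

/** Hypothetical prefetching iterator: a loader thread stays up to `prefetch` batches ahead. */
class Prefetcher<T> implements Iterator<T> {
    private static final Object END = new Object(); // marks the end of the source
    private final BlockingQueue<Object> queue;
    private Object next; // batch taken from the queue but not yet returned

    /** Source elements must be non-null, since the queue rejects null. */
    Prefetcher(Iterator<T> source, int prefetch) {
        if (prefetch < 1) {
            throw new IllegalArgumentException("prefetch must be >= 1; iterate the source directly to disable it");
        }
        queue = new ArrayBlockingQueue<>(prefetch);
        Thread loader = new Thread(() -> {
            try {
                while (source.hasNext()) {
                    queue.put(source.next()); // blocks while `prefetch` batches are already waiting
                }
                queue.put(END);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        loader.setDaemon(true); // do not keep the JVM alive if the consumer stops early
        loader.start();
    }

    @Override
    public boolean hasNext() {
        if (next == null) {
            try {
                next = queue.take(); // waits only if the loader has fallen behind
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
        return next != END;
    }

    @Override
    @SuppressWarnings("unchecked")
    public T next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        T batch = (T) next;
        next = null;
        return batch;
    }

    public static void main(String[] args) {
        List<String> batches = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            batches.add("batch-" + i);
        }
        Prefetcher<String> it = new Prefetcher<>(batches.iterator(), 2);
        while (it.hasNext()) {
            System.out.println("training on " + it.next());
        }
    }
}
```

The bounded queue is the design point. Its capacity caps how much memory prefetching can use, and the loader blocks instead of reading the whole source ahead of the consumer.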

Modifier and Type | Method and Description |
---|---|
void | addHook(TrainingHook trainingHook): Add a hook to the master for pre- and post-training. |
boolean | deleteTempFiles(org.apache.spark.api.java.JavaSparkContext sc): Attempt to delete any temporary files generated by this TrainingMaster. |
boolean | deleteTempFiles(org.apache.spark.SparkContext sc): Attempt to delete any temporary files generated by this TrainingMaster. |
void | executeTraining(SparkComputationGraph graph, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData): Train the SparkComputationGraph with the specified serialized DataSet objects. |
void | executeTraining(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData): Train the SparkComputationGraph with the specified data set. |
void | executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData): Deprecated. Due to poor performance. |
void | executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData): Train the SparkDl4jMultiLayer with the specified data set. |
void | executeTrainingMDS(SparkComputationGraph graph, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData): Train the SparkComputationGraph with the specified serialized MultiDataSet objects. |
void | executeTrainingMDS(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> trainingData): Train the SparkComputationGraph with the specified data set. |
void | executeTrainingPaths(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths): EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized DataSet objects. |
void | executeTrainingPaths(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths): EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized DataSet objects. |
void | executeTrainingPathsMDS(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingMultiDataPaths): EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized MultiDataSet objects. |
static ParameterAveragingTrainingMaster | fromJson(java.lang.String jsonStr): Create a ParameterAveragingTrainingMaster instance by deserializing a JSON string that has been serialized with toJson(). |
static ParameterAveragingTrainingMaster | fromYaml(java.lang.String yamlStr): Create a ParameterAveragingTrainingMaster instance by deserializing a YAML string that has been serialized with toYaml(). |
boolean | getIsCollectTrainingStats(): Get the current setting for collectTrainingStats. |
SparkTrainingStats | getTrainingStats(): Return the training statistics. |
ParameterAveragingTrainingWorker | getWorkerInstance(SparkComputationGraph graph): Get the worker instance for this training master. |
ParameterAveragingTrainingWorker | getWorkerInstance(SparkDl4jMultiLayer network): Get the worker instance for this training master. |
void | removeHook(TrainingHook trainingHook): Remove a training hook from the worker. |
void | setCollectTrainingStats(boolean collectTrainingStats): Set whether the training statistics should be collected. |
void | setListeners(java.util.Collection<IterationListener> listeners): Set the iteration listeners. |
void | setListeners(StatsStorageRouter statsStorage, java.util.Collection<IterationListener> listeners): Set the iteration listeners and the StatsStorageRouter. |
java.lang.String | toJson(): Get the TrainingMaster configuration as JSON. |
java.lang.String | toYaml(): Get the TrainingMaster configuration as YAML. |

public ParameterAveragingTrainingMaster(boolean saveUpdater, java.lang.Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches)

public ParameterAveragingTrainingMaster(boolean saveUpdater, java.lang.Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches, Repartition repartition, RepartitionStrategy repartitionStrategy, boolean collectTrainingStats)
saveUpdater - If true: save (and average) the updater state when doing parameter averaging
numWorkers - Number of workers (executors * threads per executor) for the cluster
rddDataSetNumExamples - Number of examples in each DataSet object in the RDD<DataSet>
batchSizePerWorker - Number of examples to use per worker per fit
averagingFrequency - Frequency (in number of minibatches) with which to average parameters
prefetchNumBatches - Number of batches to asynchronously prefetch (0: disable)
collectTrainingStats - If true: collect training statistics for debugging/optimization purposes

public ParameterAveragingTrainingMaster(boolean saveUpdater, java.lang.Integer numWorkers, int rddDataSetNumExamples, int batchSizePerWorker, int averagingFrequency, int prefetchNumBatches, Repartition repartition, RepartitionStrategy repartitionStrategy, org.apache.spark.storage.StorageLevel storageLevel, boolean collectTrainingStats)
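Read together, the parameter descriptions above imply some simple bookkeeping. Between two averaging steps each worker fits averagingFrequency minibatches of batchSizePerWorker examples, and numWorkers workers do so in parallel. The sketch below spells out that arithmetic. It also converts the per-worker example count into DataSet objects of rddDataSetNumExamples examples each, rounding up. That conversion is an assumption made for illustration, not a description of how the master actually splits the RDD. The AveragingRoundMath class is hypothetical.

```java
/** Hypothetical helper relating the constructor parameters within one averaging round. */
class AveragingRoundMath {

    /** Examples one worker fits between two averaging steps. */
    static int examplesPerWorkerPerRound(int batchSizePerWorker, int averagingFrequency) {
        return batchSizePerWorker * averagingFrequency;
    }

    /**
     * DataSet objects one worker needs per round when each RDD element holds
     * rddDataSetNumExamples examples. Rounding up is an assumption of this sketch.
     */
    static int dataSetObjectsPerWorkerPerRound(int batchSizePerWorker, int averagingFrequency,
                                               int rddDataSetNumExamples) {
        int examples = examplesPerWorkerPerRound(batchSizePerWorker, averagingFrequency);
        return (examples + rddDataSetNumExamples - 1) / rddDataSetNumExamples; // ceiling division
    }

    /** Examples consumed by the whole cluster per round (numWorkers = executors * threads per executor). */
    static long examplesPerRound(int numWorkers, int batchSizePerWorker, int averagingFrequency) {
        return (long) numWorkers * batchSizePerWorker * averagingFrequency;
    }

    public static void main(String[] args) {
        // Illustrative values: 2 executors * 2 threads = 4 workers, minibatches of 32,
        // averaging every 5 minibatches.
        System.out.println(examplesPerWorkerPerRound(32, 5));           // 160
        System.out.println(dataSetObjectsPerWorkerPerRound(32, 5, 32)); // 5
        System.out.println(dataSetObjectsPerWorkerPerRound(32, 5, 50)); // 4 (160 examples, rounded up)
        System.out.println(examplesPerRound(4, 32, 5));                 // 640
    }
}
```

Under these assumptions, setting rddDataSetNumExamples equal to batchSizePerWorker makes each DataSet object correspond to exactly one minibatch.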

public void removeHook(TrainingHook trainingHook)
Remove a training hook from the worker.
Specified by: removeHook in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
trainingHook - the training hook to remove

public void addHook(TrainingHook trainingHook)
Add a hook to the master for pre- and post-training.
Specified by: addHook in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
trainingHook - the training hook to add

public java.lang.String toJson()
Get the TrainingMaster configuration as JSON.
Specified by: toJson in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
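addHook and removeHook register and unregister TrainingHook instances. The registration pattern is sketched below with a hypothetical Hook interface whose preUpdate and postUpdate callbacks run around each update step. The interface, its signatures, and the HookRegistry class are illustrative assumptions and do not reproduce the DL4J TrainingHook API.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical hook with callbacks before and after an update step (not DL4J's TrainingHook). */
interface Hook {
    void preUpdate(int step);

    void postUpdate(int step);
}

/** Hypothetical registry mirroring the add/remove pattern of addHook and removeHook. */
class HookRegistry {
    private final List<Hook> hooks = new ArrayList<>();

    void addHook(Hook hook) {
        hooks.add(hook);
    }

    void removeHook(Hook hook) {
        hooks.remove(hook); // removes one registration equal to this hook (identity by default)
    }

    /** Runs one update, calling every registered hook before and after it. */
    void runStep(int step, Runnable update) {
        for (Hook h : hooks) {
            h.preUpdate(step);
        }
        update.run();
        for (Hook h : hooks) {
            h.postUpdate(step);
        }
    }

    public static void main(String[] args) {
        HookRegistry registry = new HookRegistry();
        Hook logger = new Hook() {
            @Override
            public void preUpdate(int step) {
                System.out.println("before step " + step);
            }

            @Override
            public void postUpdate(int step) {
                System.out.println("after step " + step);
            }
        };
        registry.addHook(logger);
        registry.runStep(0, () -> System.out.println("updating"));
        registry.removeHook(logger);
        registry.runStep(1, () -> System.out.println("updating")); // no hook output any more
    }
}
```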

public java.lang.String toYaml()
Get the TrainingMaster configuration as YAML.
Specified by: toYaml in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>

public static ParameterAveragingTrainingMaster fromJson(java.lang.String jsonStr)
Create a ParameterAveragingTrainingMaster instance by deserializing a JSON string that has been serialized with toJson().
jsonStr - ParameterAveragingTrainingMaster configuration serialized as JSON

public static ParameterAveragingTrainingMaster fromYaml(java.lang.String yamlStr)
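Create a ParameterAveragingTrainingMaster instance by deserializing a YAML string that has been serialized with toYaml().
yamlStr - ParameterAveragingTrainingMaster configuration serialized as YAML

public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network)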
Get the worker instance for this training master.
Specified by: getWorkerInstance in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
network - Current SparkDl4jMultiLayer

public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph)
Get the worker instance for this training master.
Specified by: getWorkerInstance in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
graph - Current SparkComputationGraph

public void executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData)
Train the SparkDl4jMultiLayer with the specified data set.
Specified by: executeTraining in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
network - Current network state
trainingData - Data to train on

@Deprecated public void executeTraining(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData)
Deprecated. Due to poor performance.
Specified by: executeTraining in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
network - Current network state
trainingData - Data to train on

public void executeTrainingPaths(SparkDl4jMultiLayer network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths)
EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized DataSet objects.
Specified by: executeTrainingPaths in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
network - Current network state
trainingDataPaths - Data to train on

public void executeTraining(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> trainingData)
Train the SparkComputationGraph with the specified data set.
Specified by: executeTraining in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
graph - Current network state
trainingData - Data to train on

public void executeTrainingMDS(SparkComputationGraph graph, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> trainingData)
Train the SparkComputationGraph with the specified data set.
Specified by: executeTrainingMDS in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
graph - Current network state
trainingData - Data to train on

public void executeTraining(SparkComputationGraph graph, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData)
Train the SparkComputationGraph with the specified serialized DataSet objects (see DataSet.save(OutputStream)).
Specified by: executeTraining in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
graph - Current network state
trainingData - Data to train on

public void executeTrainingMDS(SparkComputationGraph graph, org.apache.spark.api.java.JavaPairRDD<java.lang.String,org.apache.spark.input.PortableDataStream> trainingData)
Train the SparkComputationGraph with the specified serialized MultiDataSet objects.
Specified by: executeTrainingMDS in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
graph - Current network state
trainingData - Data to train on

public void executeTrainingPaths(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingDataPaths)
EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized DataSet objects.
Specified by: executeTrainingPaths in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
network - Current network state
trainingDataPaths - Data to train on

public void executeTrainingPathsMDS(SparkComputationGraph network, org.apache.spark.api.java.JavaRDD<java.lang.String> trainingMultiDataPaths)
EXPERIMENTAL method, may be removed in a future release. Fit the network using a list of paths for serialized MultiDataSet objects.
Specified by: executeTrainingPathsMDS in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
network - Current network state
trainingMultiDataPaths - Data to train on

public void setCollectTrainingStats(boolean collectTrainingStats)
Set whether the training statistics should be collected. These statistics are primarily used for debugging and optimization, to give some insight into which aspects of network training take the most time.
Specified by: setCollectTrainingStats in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
collectTrainingStats - If true: training statistics will be collected

public boolean getIsCollectTrainingStats()
Get the current setting for collectTrainingStats.
Specified by: getIsCollectTrainingStats in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>

public SparkTrainingStats getTrainingStats()
Return the training statistics.
Specified by: getTrainingStats in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
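setCollectTrainingStats, getIsCollectTrainingStats and getTrainingStats together describe an opt-in switch. Timing information is gathered only when requested, to show which parts of training take the most time. A minimal sketch of that pattern follows. The PhaseTimer class, its method names, and the per-phase nanosecond totals are hypothetical and do not model the contents of SparkTrainingStats.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;

/** Hypothetical per-phase timer that only records while collection is enabled. */
class PhaseTimer {
    private boolean collect;
    private final Map<String, Long> totalNanos = new LinkedHashMap<>();

    void setCollect(boolean collect) {
        this.collect = collect;
    }

    boolean isCollecting() {
        return collect;
    }

    /** Runs the work and, if collection is on, adds its elapsed time to the named phase. */
    <T> T time(String phase, Supplier<T> work) {
        if (!collect) {
            return work.get(); // no bookkeeping at all when disabled
        }
        long start = System.nanoTime();
        try {
            return work.get();
        } finally {
            totalNanos.merge(phase, System.nanoTime() - start, Long::sum);
        }
    }

    /** Read-only snapshot of accumulated nanoseconds per phase, in first-seen order. */
    Map<String, Long> snapshot() {
        return Collections.unmodifiableMap(new LinkedHashMap<>(totalNanos));
    }

    public static void main(String[] args) {
        PhaseTimer timer = new PhaseTimer();
        timer.time("fit", () -> 1); // not recorded: collection is off
        timer.setCollect(true);
        timer.time("fit", () -> 1);
        timer.time("average", () -> 2);
        System.out.println(timer.snapshot().keySet()); // [fit, average]
    }
}
```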

public void setListeners(java.util.Collection<IterationListener> listeners)
Set the iteration listeners.
Specified by: setListeners in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
listeners - Listeners to set

public void setListeners(StatsStorageRouter statsStorage, java.util.Collection<IterationListener> listeners)
Set the iteration listeners and the StatsStorageRouter.
Specified by: setListeners in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
statsStorage - StatsStorageRouter in which to place the results
listeners - Listeners

public boolean deleteTempFiles(org.apache.spark.api.java.JavaSparkContext sc)
Attempt to delete any temporary files generated by this TrainingMaster.
Specified by: deleteTempFiles in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
sc - JavaSparkContext (used to access HDFS or other file systems, when required)

public boolean deleteTempFiles(org.apache.spark.SparkContext sc)
Attempt to delete any temporary files generated by this TrainingMaster.
Specified by: deleteTempFiles in interface TrainingMaster<ParameterAveragingTrainingResult,ParameterAveragingTrainingWorker>
sc - SparkContext (used to access HDFS or other file systems, when required)
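deleteTempFiles is documented as a best-effort operation ("attempt to delete") that reports its outcome as a boolean, with the Spark context used to reach HDFS or other file systems. The sketch below shows that contract on the local file system only, using java.nio.file. The TempFileTracker class is hypothetical. The sketch also assumes that a return value of true means every tracked file is gone; the documentation above does not specify this.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Hypothetical tracker for temporary files created during training (local file system only). */
class TempFileTracker {
    private final List<Path> created = new ArrayList<>();

    /** Creates a temporary file and remembers it for later cleanup. */
    Path newTempFile(String prefix) throws IOException {
        Path path = Files.createTempFile(prefix, ".bin");
        created.add(path);
        return path;
    }

    /**
     * Best-effort cleanup: tries every tracked file, keeps the ones that could
     * not be deleted for a later retry, and returns true only if none remain.
     */
    boolean deleteTempFiles() {
        Iterator<Path> it = created.iterator();
        while (it.hasNext()) {
            Path path = it.next();
            try {
                Files.deleteIfExists(path);
                it.remove();
            } catch (IOException e) {
                // leave the path tracked; one failure should not stop the rest
            }
        }
        return created.isEmpty();
    }

    public static void main(String[] args) throws IOException {
        TempFileTracker tracker = new TempFileTracker();
        Path a = tracker.newTempFile("worker-data");
        Path b = tracker.newTempFile("worker-data");
        System.out.println(tracker.deleteTempFiles());          // true
        System.out.println(Files.exists(a) || Files.exists(b)); // false
    }
}
```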