public class ParameterAveragingTrainingWorker extends java.lang.Object implements TrainingWorker<ParameterAveragingTrainingResult>
Constructor and Description |
---|
ParameterAveragingTrainingWorker(org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, java.util.Collection<TrainingHook> trainingHooks, java.util.Collection<IterationListener> listeners, StatsStorageRouterProvider routerProvider) |
public ParameterAveragingTrainingWorker(org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, java.util.Collection<TrainingHook> trainingHooks, java.util.Collection<IterationListener> listeners, StatsStorageRouterProvider routerProvider)
public void removeHook(TrainingHook trainingHook)
removeHook
in interface TrainingWorker<ParameterAveragingTrainingResult>
trainingHook
- the training hook to remove
public void addHook(TrainingHook trainingHook)
addHook
in interface TrainingWorker<ParameterAveragingTrainingResult>
trainingHook
- the training hook to add
public MultiLayerNetwork getInitialModel()
TrainingWorker
getInitialModel
in interface TrainingWorker<ParameterAveragingTrainingResult>
public ComputationGraph getInitialModelGraph()
TrainingWorker
getInitialModelGraph
in interface TrainingWorker<ParameterAveragingTrainingResult>
public ParameterAveragingTrainingResult processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast)
TrainingWorker
processMinibatch
in interface TrainingWorker<ParameterAveragingTrainingResult>
dataSet
- Data set to train on
network
- Network to train
isLast
- If true: last data set currently available. If false: more data sets will be processed for this executor.
public ParameterAveragingTrainingResult processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
processMinibatch
in interface TrainingWorker<ParameterAveragingTrainingResult>
dataSet
- Data set to train on
graph
- Network to train
isLast
- If true: last data set currently available. If false: more data sets will be processed for this executor.
public ParameterAveragingTrainingResult processMinibatch(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
processMinibatch
in interface TrainingWorker<ParameterAveragingTrainingResult>
dataSet
- Data set to train on
graph
- Network to train
isLast
- If true: last data set currently available. If false: more data sets will be processed for this executor.
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast)
TrainingWorker
Same as TrainingWorker.processMinibatch(DataSet, MultiLayerNetwork, boolean),
but used when SparkTrainingStats
are being collected.
processMinibatchWithStats
in interface TrainingWorker<ParameterAveragingTrainingResult>
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
Same as TrainingWorker.processMinibatch(DataSet, ComputationGraph, boolean),
but used when SparkTrainingStats
are being collected.
processMinibatchWithStats
in interface TrainingWorker<ParameterAveragingTrainingResult>
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
TrainingWorker
Same as TrainingWorker.processMinibatch(MultiDataSet, ComputationGraph, boolean),
but used when SparkTrainingStats
are being collected.
processMinibatchWithStats
in interface TrainingWorker<ParameterAveragingTrainingResult>
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network)
TrainingWorker
getFinalResult
in interface TrainingWorker<ParameterAveragingTrainingResult>
network
- Current state of the network
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network)
TrainingWorker
getFinalResult
in interface TrainingWorker<ParameterAveragingTrainingResult>
network
- Current state of the network
public ParameterAveragingTrainingResult getFinalResultNoData()
TrainingWorker
getFinalResultNoData
in interface TrainingWorker<ParameterAveragingTrainingResult>
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultNoDataWithStats()
TrainingWorker
Same as TrainingWorker.getFinalResultNoData(),
but used when SparkTrainingStats
are being collected.
getFinalResultNoDataWithStats
in interface TrainingWorker<ParameterAveragingTrainingResult>
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network)
TrainingWorker
Same as TrainingWorker.getFinalResult(MultiLayerNetwork),
but used when SparkTrainingStats
are being collected.
getFinalResultWithStats
in interface TrainingWorker<ParameterAveragingTrainingResult>
public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph)
TrainingWorker
Same as TrainingWorker.getFinalResult(ComputationGraph),
but used when SparkTrainingStats
are being collected.
getFinalResultWithStats
in interface TrainingWorker<ParameterAveragingTrainingResult>
public WorkerConfiguration getDataConfiguration()
TrainingWorker
Get the WorkerConfiguration
that contains information such as minibatch sizes, etc.
getDataConfiguration
in interface TrainingWorker<ParameterAveragingTrainingResult>