public class ParameterAveragingTrainingWorker extends java.lang.Object implements TrainingWorker<ParameterAveragingTrainingResult>
| Constructor and Description |
|---|
| `ParameterAveragingTrainingWorker(org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, java.util.Collection<TrainingHook> trainingHooks, java.util.Collection<IterationListener> listeners, StatsStorageRouterProvider routerProvider)` |
public ParameterAveragingTrainingWorker(org.apache.spark.broadcast.Broadcast<NetBroadcastTuple> broadcast, boolean saveUpdater, WorkerConfiguration configuration, java.util.Collection<TrainingHook> trainingHooks, java.util.Collection<IterationListener> listeners, StatsStorageRouterProvider routerProvider)
public void removeHook(TrainingHook trainingHook)
Specified by: removeHook in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
trainingHook - the training hook to remove

public void addHook(TrainingHook trainingHook)
Specified by: addHook in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
trainingHook - the training hook to add

public MultiLayerNetwork getInitialModel()
Description copied from interface: TrainingWorker
Specified by: getInitialModel in interface TrainingWorker<ParameterAveragingTrainingResult>

public ComputationGraph getInitialModelGraph()
Description copied from interface: TrainingWorker
Specified by: getInitialModelGraph in interface TrainingWorker<ParameterAveragingTrainingResult>

public ParameterAveragingTrainingResult processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast)
Description copied from interface: TrainingWorker
Specified by: processMinibatch in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
dataSet - Data set to train on
network - Network to train
isLast - If true: last data set currently available. If false: more data sets will be processed for this executor.

public ParameterAveragingTrainingResult processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast)
Description copied from interface: TrainingWorker
Specified by: processMinibatch in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
dataSet - Data set to train on
graph - Network to train
isLast - If true: last data set currently available. If false: more data sets will be processed for this executor.

public ParameterAveragingTrainingResult processMinibatch(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
Description copied from interface: TrainingWorker
Specified by: processMinibatch in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
dataSet - Data set to train on
graph - Network to train
isLast - If true: last data set currently available. If false: more data sets will be processed for this executor.

public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast)
Description copied from interface: TrainingWorker
Same as TrainingWorker.processMinibatch(DataSet, MultiLayerNetwork, boolean), but used when SparkTrainingStats are being collected.
Specified by: processMinibatchWithStats in interface TrainingWorker<ParameterAveragingTrainingResult>

public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast)
Description copied from interface: TrainingWorker
Same as TrainingWorker.processMinibatch(DataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected.
Specified by: processMinibatchWithStats in interface TrainingWorker<ParameterAveragingTrainingResult>

public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
Description copied from interface: TrainingWorker
Same as TrainingWorker.processMinibatch(MultiDataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected.
Specified by: processMinibatchWithStats in interface TrainingWorker<ParameterAveragingTrainingResult>

public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network)
Description copied from interface: TrainingWorker
Specified by: getFinalResult in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
network - Current state of the network

public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network)
Description copied from interface: TrainingWorker
Specified by: getFinalResult in interface TrainingWorker<ParameterAveragingTrainingResult>
Parameters:
network - Current state of the network

public ParameterAveragingTrainingResult getFinalResultNoData()
Description copied from interface: TrainingWorker
Specified by: getFinalResultNoData in interface TrainingWorker<ParameterAveragingTrainingResult>

public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultNoDataWithStats()
Description copied from interface: TrainingWorker
Same as TrainingWorker.getFinalResultNoData(), but used when SparkTrainingStats are being collected.
Specified by: getFinalResultNoDataWithStats in interface TrainingWorker<ParameterAveragingTrainingResult>

public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network)
Description copied from interface: TrainingWorker
Same as TrainingWorker.getFinalResult(MultiLayerNetwork), but used when SparkTrainingStats are being collected.
Specified by: getFinalResultWithStats in interface TrainingWorker<ParameterAveragingTrainingResult>

public Pair<ParameterAveragingTrainingResult,SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph)
Description copied from interface: TrainingWorker
Same as TrainingWorker.getFinalResult(ComputationGraph), but used when SparkTrainingStats are being collected.
Specified by: getFinalResultWithStats in interface TrainingWorker<ParameterAveragingTrainingResult>

public WorkerConfiguration getDataConfiguration()
Description copied from interface: TrainingWorker
Get the WorkerConfiguration that contains information such as minibatch sizes, etc.
Specified by: getDataConfiguration in interface TrainingWorker<ParameterAveragingTrainingResult>