public interface TrainingWorker<R extends TrainingResult>
extends java.io.Serializable
Results are passed back to the TrainingMaster for processing. TrainingWorker implementations provide a layer of abstraction for network learning that should allow for more flexibility and control over how learning is conducted (including, for example, asynchronous communication).
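Because TrainingWorker extends java.io.Serializable, an implementation, including any state it carries, has to survive Java serialization (presumably so that it can be sent to remote executors). The sketch below checks that property with a serialize/deserialize round trip. It does not use DL4J: SketchWorker, SketchResult, CountingWorker and CountResult are hypothetical stand-ins, a minibatch is modelled as a plain double[], and returning a result only when isLast is true is this sketch's own choice.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a result type (not DL4J's TrainingResult).
interface SketchResult extends Serializable {}

// Hypothetical stand-in for TrainingWorker<R>: generic over its result and Serializable.
interface SketchWorker<R extends SketchResult> extends Serializable {
    R processMinibatch(double[] minibatch, boolean isLast);
}

final class CountResult implements SketchResult {
    private static final long serialVersionUID = 1L;
    final int examplesSeen;
    CountResult(int examplesSeen) { this.examplesSeen = examplesSeen; }
}

// Counts examples and returns a result only for the last minibatch (a choice of this sketch).
final class CountingWorker implements SketchWorker<CountResult> {
    private static final long serialVersionUID = 1L;
    private int examplesSeen = 0;

    @Override
    public CountResult processMinibatch(double[] minibatch, boolean isLast) {
        examplesSeen += minibatch.length;
        return isLast ? new CountResult(examplesSeen) : null;
    }
}

final class SerializationRoundTrip {
    // Writes the object to bytes and reads it back, similar to what a remote transfer involves.
    @SuppressWarnings("unchecked")
    static <T extends Serializable> T roundTrip(T obj) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return (T) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        CountingWorker worker = new CountingWorker();
        worker.processMinibatch(new double[]{1, 2}, false); // state before shipping: 2 examples
        CountingWorker copy = roundTrip(worker);
        CountResult result = copy.processMinibatch(new double[]{3}, true);
        System.out.println(result.examplesSeen); // prints 3
    }
}
```

The copy continues from the two examples counted before serialization, which is the point: worker state travels with the serialized object.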
Modifier and Type | Method | Description |
---|---|---|
void | addHook(TrainingHook trainingHook) | Add a training hook to be used during training of the worker |
WorkerConfiguration | getDataConfiguration() | Get the WorkerConfiguration that contains information such as minibatch sizes, etc. |
R | getFinalResult(ComputationGraph graph) | Get the final result to be returned to the driver |
R | getFinalResult(MultiLayerNetwork network) | Get the final result to be returned to the driver |
R | getFinalResultNoData() | Get the final result to be returned to the driver, if no data was available for this executor |
Pair<R,SparkTrainingStats> | getFinalResultNoDataWithStats() | As per getFinalResultNoData(), but used when SparkTrainingStats are being collected |
Pair<R,SparkTrainingStats> | getFinalResultWithStats(ComputationGraph graph) | As per getFinalResult(ComputationGraph), but used when SparkTrainingStats are being collected |
Pair<R,SparkTrainingStats> | getFinalResultWithStats(MultiLayerNetwork network) | As per getFinalResult(MultiLayerNetwork), but used when SparkTrainingStats are being collected |
MultiLayerNetwork | getInitialModel() | Get the initial model when training a MultiLayerNetwork/SparkDl4jMultiLayer |
ComputationGraph | getInitialModelGraph() | Get the initial model when training a ComputationGraph/SparkComputationGraph |
R | processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast) | Process (fit) a minibatch for a ComputationGraph |
R | processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast) | Process (fit) a minibatch for a MultiLayerNetwork |
R | processMinibatch(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast) | Process (fit) a minibatch for a ComputationGraph using a MultiDataSet |
Pair<R,SparkTrainingStats> | processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast) | As per processMinibatch(DataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected |
Pair<R,SparkTrainingStats> | processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast) | As per processMinibatch(DataSet, MultiLayerNetwork, boolean), but used when SparkTrainingStats are being collected |
Pair<R,SparkTrainingStats> | processMinibatchWithStats(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast) | As per processMinibatch(MultiDataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected |
void | removeHook(TrainingHook trainingHook) | Remove a training hook from the worker |
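Read together, the summary implies a call sequence on each executor: processMinibatch once per minibatch, with isLast set on the final one, followed by getFinalResult, or getFinalResultNoData when the executor received no data. The following sketch drives a hypothetical, trimmed-down LifecycleWorker interface in that order. The interface, the double[] minibatches, and the rule that a non-null result from processMinibatch ends the loop early are assumptions for illustration, not a description of how DL4J's TrainingMaster implementations actually schedule these calls.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

// Hypothetical, trimmed-down analogue of TrainingWorker<R> (no network or stats arguments).
interface LifecycleWorker<R> {
    R processMinibatch(double[] minibatch, boolean isLast);
    R getFinalResult();
    R getFinalResultNoData();
}

// Records the order in which its methods are called.
final class LoggingWorker implements LifecycleWorker<String> {
    final List<String> calls = new ArrayList<>();

    @Override
    public String processMinibatch(double[] minibatch, boolean isLast) {
        calls.add("process(" + minibatch.length + "," + isLast + ")");
        return null; // no early result, so the loop continues
    }

    @Override
    public String getFinalResult() {
        calls.add("final");
        return "done";
    }

    @Override
    public String getFinalResultNoData() {
        calls.add("noData");
        return "empty";
    }
}

final class ExecutorLoop {
    // Feeds every minibatch to the worker, flagging the last one, then requests the final result.
    static <R> R run(LifecycleWorker<R> worker, Iterator<double[]> data) {
        if (!data.hasNext()) {
            return worker.getFinalResultNoData(); // this executor received no data
        }
        while (data.hasNext()) {
            double[] batch = data.next();
            boolean isLast = !data.hasNext();
            R early = worker.processMinibatch(batch, isLast);
            if (early != null) {
                return early; // assumption of this sketch: a non-null result ends training early
            }
        }
        return worker.getFinalResult();
    }

    public static void main(String[] args) {
        List<double[]> batches = new ArrayList<>();
        batches.add(new double[]{1, 2});
        batches.add(new double[]{3});
        LoggingWorker worker = new LoggingWorker();
        System.out.println(run(worker, batches.iterator())); // prints done
        System.out.println(worker.calls); // prints [process(2,false), process(1,true), final]
    }
}
```

Passing an empty iterator instead makes run call only getFinalResultNoData and return "empty".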
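
void removeHook(TrainingHook trainingHook)
Remove a training hook from the worker
- trainingHook: the training hook to remove

void addHook(TrainingHook trainingHook)
Add a training hook to be used during training of the worker
- trainingHook: the training hook to add

MultiLayerNetwork getInitialModel()
Get the initial model when training a MultiLayerNetwork/SparkDl4jMultiLayer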

ComputationGraph getInitialModelGraph()
Get the initial model when training a ComputationGraph/SparkComputationGraph
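The addHook and removeHook entries only say that hooks are registered with or removed from the worker. One plausible arrangement, sketched below, is for the worker to keep a list of hooks and call each of them before and after fitting a minibatch. MinibatchHook, with its before/after methods, is a hypothetical interface and does not reproduce the methods of DL4J's TrainingHook; counting examples stands in for actually fitting a network.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical hook interface; not the real TrainingHook.
interface MinibatchHook {
    void before(int batchSize);
    void after(int batchSize);
}

// Hypothetical worker that notifies its registered hooks around each minibatch.
final class HookedWorker {
    private final List<MinibatchHook> hooks = new ArrayList<>();
    private int examplesSeen = 0;

    void addHook(MinibatchHook hook) { hooks.add(hook); }

    void removeHook(MinibatchHook hook) { hooks.remove(hook); }

    void processMinibatch(double[] minibatch) {
        for (MinibatchHook h : hooks) h.before(minibatch.length);
        examplesSeen += minibatch.length; // stand-in for fitting the network
        for (MinibatchHook h : hooks) h.after(minibatch.length);
    }

    int examplesSeen() { return examplesSeen; }
}

// Counts how often it is invoked.
final class CountingHook implements MinibatchHook {
    int beforeCalls = 0;
    int afterCalls = 0;

    @Override public void before(int batchSize) { beforeCalls++; }
    @Override public void after(int batchSize) { afterCalls++; }
}

final class HookDemo {
    public static void main(String[] args) {
        HookedWorker worker = new HookedWorker();
        CountingHook hook = new CountingHook();
        worker.addHook(hook);
        worker.processMinibatch(new double[]{1, 2}); // hook sees this minibatch
        worker.removeHook(hook);
        worker.processMinibatch(new double[]{3});    // hook no longer called
        System.out.println(hook.beforeCalls + " " + hook.afterCalls + " " + worker.examplesSeen()); // prints 1 1 3
    }
}
```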

R processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast)
Process (fit) a minibatch for a MultiLayerNetwork
- dataSet: Data set to train on
- network: Network to train
- isLast: If true, this is the last data set currently available; if false, more data sets will be processed for this executor

R processMinibatch(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast)
Process (fit) a minibatch for a ComputationGraph
- dataSet: Data set to train on
- graph: Network to train
- isLast: If true, this is the last data set currently available; if false, more data sets will be processed for this executor

R processMinibatch(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
Process (fit) a minibatch for a ComputationGraph using a MultiDataSet
- dataSet: Data set to train on
- graph: Network to train
- isLast: If true, this is the last data set currently available; if false, more data sets will be processed for this executor

Pair<R,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, MultiLayerNetwork network, boolean isLast)
As per processMinibatch(DataSet, MultiLayerNetwork, boolean), but used when SparkTrainingStats are being collected

Pair<R,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.DataSet dataSet, ComputationGraph graph, boolean isLast)
As per processMinibatch(DataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected

Pair<R,SparkTrainingStats> processMinibatchWithStats(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, ComputationGraph graph, boolean isLast)
As per processMinibatch(MultiDataSet, ComputationGraph, boolean), but used when SparkTrainingStats are being collected

R getFinalResult(MultiLayerNetwork network)
Get the final result to be returned to the driver
- network: Current state of the network

R getFinalResult(ComputationGraph graph)
Get the final result to be returned to the driver
- graph: Current state of the network

R getFinalResultNoData()
Get the final result to be returned to the driver, if no data was available for this executor
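Each ...WithStats method returns a Pair of the ordinary result R and SparkTrainingStats. A minimal way to get that shape, assuming the stats are simply collected around the existing call, is a wrapper that times the delegate and returns the result together with a cumulative stats object. SimplePair, MinibatchStats, PlainWorker and StatsCollectingWorker are hypothetical stand-ins; they are not DL4J's Pair or SparkTrainingStats, and the timing recorded here is only an example of what such stats might hold.

```java
// Hypothetical stand-in shaped like Pair<R, SparkTrainingStats>.
final class SimplePair<A, B> {
    final A first;
    final B second;
    SimplePair(A first, B second) { this.first = first; this.second = second; }
}

// Hypothetical stand-in for SparkTrainingStats: just a count and a total time.
final class MinibatchStats {
    int minibatchCount;
    long totalNanos;
}

// Hypothetical plain worker whose results the wrapper decorates with stats.
interface PlainWorker<R> {
    R processMinibatch(double[] minibatch, boolean isLast);
}

// Times each call to the delegate and returns its result alongside the cumulative stats.
final class StatsCollectingWorker<R> {
    private final PlainWorker<R> delegate;
    private final MinibatchStats stats = new MinibatchStats();

    StatsCollectingWorker(PlainWorker<R> delegate) { this.delegate = delegate; }

    SimplePair<R, MinibatchStats> processMinibatchWithStats(double[] minibatch, boolean isLast) {
        long start = System.nanoTime();
        R result = delegate.processMinibatch(minibatch, isLast);
        stats.totalNanos += System.nanoTime() - start;
        stats.minibatchCount++;
        return new SimplePair<R, MinibatchStats>(result, stats);
    }
}

final class StatsDemo {
    public static void main(String[] args) {
        // Delegate returns the size of the last minibatch, and null before that.
        StatsCollectingWorker<Integer> worker =
                new StatsCollectingWorker<Integer>((batch, isLast) -> isLast ? batch.length : null);
        worker.processMinibatchWithStats(new double[]{1, 2}, false);
        SimplePair<Integer, MinibatchStats> last = worker.processMinibatchWithStats(new double[]{3}, true);
        System.out.println(last.first + " " + last.second.minibatchCount); // prints 1 2
    }
}
```

Keeping the stats in a wrapper leaves the plain processMinibatch path untouched when stats are not being collected, which mirrors the split between the plain and ...WithStats methods.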

Pair<R,SparkTrainingStats> getFinalResultNoDataWithStats()
As per getFinalResultNoData(), but used when SparkTrainingStats are being collected

Pair<R,SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network)
As per getFinalResult(MultiLayerNetwork), but used when SparkTrainingStats are being collected

Pair<R,SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph)
As per getFinalResult(ComputationGraph), but used when SparkTrainingStats are being collected

WorkerConfiguration getDataConfiguration()
Get the WorkerConfiguration that contains information such as minibatch sizes, etc.
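WorkerConfiguration is only described as carrying information such as minibatch sizes. As an illustration, the sketch below uses a hypothetical SketchWorkerConfiguration with a single minibatchSize field (the real class's fields are not listed here) and splits an executor's examples into consecutive minibatches of at most that size, the last one possibly shorter.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for WorkerConfiguration; only a minibatch size is modelled.
final class SketchWorkerConfiguration implements java.io.Serializable {
    private static final long serialVersionUID = 1L;
    final int minibatchSize;

    SketchWorkerConfiguration(int minibatchSize) {
        if (minibatchSize <= 0) throw new IllegalArgumentException("minibatchSize must be positive");
        this.minibatchSize = minibatchSize;
    }
}

final class Minibatcher {
    // Splits examples into consecutive minibatches of at most config.minibatchSize entries.
    static List<double[]> split(double[] examples, SketchWorkerConfiguration config) {
        List<double[]> batches = new ArrayList<>();
        for (int start = 0; start < examples.length; start += config.minibatchSize) {
            int end = Math.min(start + config.minibatchSize, examples.length);
            batches.add(Arrays.copyOfRange(examples, start, end));
        }
        return batches;
    }

    public static void main(String[] args) {
        List<double[]> batches = split(new double[]{1, 2, 3, 4, 5}, new SketchWorkerConfiguration(2));
        for (double[] b : batches) {
            System.out.println(Arrays.toString(b)); // prints [1.0, 2.0], then [3.0, 4.0], then [5.0]
        }
    }
}
```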