public static class ParallelWrapper.Builder<T extends Model>
extends java.lang.Object
Modifier and Type | Field and Description |
---|---|
protected boolean | averageUpdaters |
protected int | averagingFrequency |
protected boolean | isMQ |
protected boolean | legacyAveraging |
protected T | model |
protected int | prefetchSize |
protected boolean | reportScore |
protected TrainerContext | trainerContext |
protected java.lang.Object[] | trainerContextArgs |
protected int | workers |
Constructor | Description |
---|---|
Builder(T model) | Builds a ParallelWrapper for the given model (for example, a MultiLayerNetwork). |
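To make the builder's shape concrete, here is a minimal, self-contained sketch of a generic builder whose constructor takes the mandatory model and whose setters return the builder for chaining. Everything in it is hypothetical: Model, Wrapper, the default values and the range checks are stand-ins chosen for illustration and are not taken from DL4J's ParallelWrapper.

```java
import java.util.Objects;

public class BuilderSketch {
    // Hypothetical stand-in for DL4J's Model type; not the real interface.
    interface Model {
        String name();
    }

    // Hypothetical immutable result of build(); the real build() returns a ParallelWrapper.
    static final class Wrapper<T extends Model> {
        final T model;
        final int workers;
        final int prefetchSize;
        final int averagingFrequency;

        private Wrapper(T model, int workers, int prefetchSize, int averagingFrequency) {
            this.model = model;
            this.workers = workers;
            this.prefetchSize = prefetchSize;
            this.averagingFrequency = averagingFrequency;
        }
    }

    // Same overall shape as ParallelWrapper.Builder<T extends Model>: the model is
    // required up front, every other setting has a default and a chainable setter.
    static final class Builder<T extends Model> {
        private final T model;
        private int workers = 2;            // illustrative defaults, not DL4J's
        private int prefetchSize = 2;
        private int averagingFrequency = 1;

        Builder(T model) {
            this.model = Objects.requireNonNull(model, "model"); // plays the role of @NonNull
        }

        Builder<T> workers(int num) {
            if (num < 1) throw new IllegalArgumentException("workers must be >= 1");
            this.workers = num;
            return this;
        }

        Builder<T> prefetchBuffer(int size) {
            if (size < 0) throw new IllegalArgumentException("size must be >= 0 (0 disables prefetching)");
            this.prefetchSize = size;
            return this;
        }

        Builder<T> averagingFrequency(int freq) {
            if (freq < 1) throw new IllegalArgumentException("freq must be >= 1");
            this.averagingFrequency = freq;
            return this;
        }

        Wrapper<T> build() {
            return new Wrapper<>(model, workers, prefetchSize, averagingFrequency);
        }
    }

    public static void main(String[] args) {
        Model net = () -> "demo-net";
        Wrapper<Model> wrapper = new Builder<>(net)
                .workers(4)
                .prefetchBuffer(8)
                .averagingFrequency(3)
                .build();
        System.out.println(wrapper.model.name() + ": workers=" + wrapper.workers
                + ", prefetch=" + wrapper.prefetchSize + ", freq=" + wrapper.averagingFrequency);
    }
}
```

Requiring the model in the constructor, rather than through a setter, means an incomplete builder cannot be created in the first place.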
Modifier and Type | Method | Description |
---|---|---|
ParallelWrapper.Builder | averageUpdaters(boolean reallyAverage) | Enables or disables averaging of updater state. |
ParallelWrapper.Builder | averagingFrequency(int freq) | Sets the model averaging frequency (number of iterations between averaging). |
ParallelWrapper | build() | Returns the configured ParallelWrapper instance. |
ParallelWrapper.Builder | prefetchBuffer(int size) | Sets the size of the buffer used for background data prefetching. |
ParallelWrapper.Builder | reportScoreAfterAveraging(boolean reallyReport) | Enables or disables reporting of the averaged model's score. |
ParallelWrapper.Builder | trainerContextArgs(java.lang.Object... trainerContextArgs) | Sets the arguments passed to the TrainerContext init method when ParallelWrapper starts training. |
ParallelWrapper.Builder | trainerFactory(TrainerContext trainerContext) | Specifies a TrainerContext for the given ParallelWrapper instance. |
ParallelWrapper.Builder | useLegacyAveraging(boolean reallyUse) | If set to true, the legacy averaging method is used. |
ParallelWrapper.Builder | useMQ(boolean reallyUse) | Enables or disables MagicQueue use. If set to true, all datasets are spread among all available devices during the prefetch phase using AsyncDataSetIterator. PLEASE NOTE: this is an experimental feature. |
ParallelWrapper.Builder | workers(int num) | Sets the number of workers used for parallel training. |
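The averagingFrequency, averageUpdaters and reportScoreAfterAveraging options all concern periodic parameter averaging across workers. The toy simulation below sketches that general technique under simplifying assumptions and does not reproduce ParallelWrapper's internals: each hypothetical Replica holds a single parameter plus a momentum buffer standing in for updater state, every worker minimises its own quadratic loss 0.5 * (param - target)^2, and the "score" is the mean loss of the averaged parameter over all workers' targets.

```java
import java.util.ArrayList;
import java.util.List;

public class AveragingSketch {
    /** Hypothetical per-worker state: one parameter plus its momentum ("updater") state. */
    static final class Replica {
        double param;
        double velocity;
        final double target; // each worker sees different data, modelled as its own target

        Replica(double param, double target) {
            this.param = param;
            this.target = target;
        }
    }

    /**
     * Every worker takes one SGD-with-momentum step per iteration. Every
     * averagingFrequency iterations the parameters (and, if averageUpdaters is true,
     * the momentum buffers) are replaced by their mean. Returns the score recorded
     * after each averaging round when reportScore is true.
     */
    static List<Double> train(List<Replica> replicas, int iterations, int averagingFrequency,
                              boolean averageUpdaters, boolean reportScore) {
        final double lr = 0.1, momentum = 0.9;
        List<Double> scores = new ArrayList<>();
        for (int iter = 1; iter <= iterations; iter++) {
            for (Replica r : replicas) {                 // local step on each worker
                double grad = r.param - r.target;
                r.velocity = momentum * r.velocity + grad;
                r.param -= lr * r.velocity;
            }
            if (iter % averagingFrequency == 0) {
                double meanParam = replicas.stream().mapToDouble(r -> r.param).average().orElse(0);
                double meanVel = replicas.stream().mapToDouble(r -> r.velocity).average().orElse(0);
                for (Replica r : replicas) {
                    r.param = meanParam;
                    if (averageUpdaters) r.velocity = meanVel;
                }
                if (reportScore) {
                    // score of the averaged model = mean loss over all workers' data
                    double score = replicas.stream()
                            .mapToDouble(r -> 0.5 * (meanParam - r.target) * (meanParam - r.target))
                            .average().orElse(0);
                    scores.add(score);
                }
            }
        }
        return scores;
    }

    public static void main(String[] args) {
        List<Replica> replicas = new ArrayList<>();
        for (double t : new double[] {1, 2, 3, 6}) replicas.add(new Replica(0.0, t));
        List<Double> scores = train(replicas, 500, 5, true, true);
        System.out.printf("averaged param = %.4f, last score = %.4f%n",
                replicas.get(0).param, scores.get(scores.size() - 1));
    }
}
```

Because each worker's update in this toy is affine with identical coefficients, the mean replica follows momentum SGD on the average loss, so it converges to 3.0 (the mean of the targets) for any averaging frequency; with non-quadratic losses that is generally no longer the case.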
protected int workers
protected int prefetchSize
protected int averagingFrequency
protected boolean reportScore
protected boolean averageUpdaters
protected boolean legacyAveraging
protected boolean isMQ
protected TrainerContext trainerContext
protected java.lang.Object[] trainerContextArgs
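The trainerContext and trainerContextArgs fields suggest a simple flow: the builder stores the varargs, and the context receives them only when training starts. The sketch below illustrates that flow with hypothetical types (SketchContext, RecordingContext, Wrapper, Builder); DL4J's real TrainerContext interface has its own method signatures, which this sketch does not attempt to reproduce.

```java
import java.util.Arrays;

public class ContextArgsSketch {
    // Hypothetical stand-in for TrainerContext; the real interface's methods differ.
    interface SketchContext {
        void init(Object... args);
    }

    // Hypothetical context that simply records what it was initialised with.
    static final class RecordingContext implements SketchContext {
        Object[] received;

        @Override
        public void init(Object... args) {
            received = args;
        }
    }

    static final class Wrapper {
        private final SketchContext context;
        private final Object[] contextArgs;

        Wrapper(SketchContext context, Object[] contextArgs) {
            this.context = context;
            // defensive copy; null is allowed, matching "the args to use (may be null)"
            this.contextArgs = contextArgs == null ? null : contextArgs.clone();
        }

        void startTraining() {
            // the stored args are forwarded only at the moment training starts;
            // an Object[] passed to Object... is used as the varargs array itself
            context.init(contextArgs);
        }
    }

    static final class Builder {
        private SketchContext context = new RecordingContext(); // default when none is given
        private Object[] contextArgs;

        Builder trainerFactory(SketchContext ctx) {
            this.context = ctx;
            return this;
        }

        Builder trainerContextArgs(Object... args) {
            this.contextArgs = args;
            return this;
        }

        Wrapper build() {
            return new Wrapper(context, contextArgs);
        }
    }

    public static void main(String[] args) {
        RecordingContext ctx = new RecordingContext();
        Wrapper wrapper = new Builder().trainerFactory(ctx).trainerContextArgs("lr", 0.01).build();
        System.out.println(ctx.received == null);   // nothing is passed before training starts
        wrapper.startTraining();
        System.out.println(Arrays.toString(ctx.received));
    }
}
```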
public Builder(@NonNull T model)
Parameters:
model - the model to train (must not be null)

public ParallelWrapper.Builder trainerContextArgs(java.lang.Object... trainerContextArgs)
Trainer context args are passed to the TrainerContext init method when ParallelWrapper starts training.
Parameters:
trainerContextArgs - the args to use (may be null)

public ParallelWrapper.Builder trainerFactory(TrainerContext trainerContext)
Specify a TrainerContext for the given ParallelWrapper instance. If none is specified, DefaultTrainerContext is used.
Parameters:
trainerContext - the trainer factory to use

public ParallelWrapper.Builder workers(int num)
Parameters:
num - number of workers to use for parallel training

public ParallelWrapper.Builder averagingFrequency(int freq)
Parameters:
freq - number of iterations between averaging

public ParallelWrapper.Builder averageUpdaters(boolean reallyAverage)
Parameters:
reallyAverage - true to average updater state, false otherwise

public ParallelWrapper.Builder useMQ(boolean reallyUse)
Parameters:
reallyUse - true to enable MagicQueue (experimental)

public ParallelWrapper.Builder prefetchBuffer(int size)
Parameters:
size - 0 to disable prefetching, or any positive number to set the buffer size

public ParallelWrapper.Builder useLegacyAveraging(boolean reallyUse)
Parameters:
reallyUse - true to use the legacy averaging method

public ParallelWrapper.Builder reportScoreAfterAveraging(boolean reallyReport)
Parameters:
reallyReport - true to report the averaged model's score

public ParallelWrapper build()
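The prefetchBuffer(int size) contract (0 disables prefetching, any positive number sets the buffer size) maps naturally onto a bounded producer/consumer queue. The sketch below shows that general technique using java.util.concurrent.ArrayBlockingQueue; it is an illustration under that assumption, not DL4J's AsyncDataSetIterator, and the helper name prefetch is hypothetical.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

public class PrefetchSketch {
    /**
     * Wraps a source iterator. With size > 0 a daemon thread fills a bounded queue of
     * that capacity ahead of the consumer; with size == 0 the source is returned as-is.
     * Source elements must be non-null, because ArrayBlockingQueue rejects null.
     */
    static <T> Iterator<T> prefetch(Iterator<T> source, int size) {
        if (size < 0) throw new IllegalArgumentException("size must be >= 0");
        if (size == 0) return source;                        // prefetching disabled
        final Object end = new Object();                     // sentinel marking exhaustion
        BlockingQueue<Object> queue = new ArrayBlockingQueue<>(size);
        Thread producer = new Thread(() -> {
            try {
                while (source.hasNext()) queue.put(source.next()); // blocks while the buffer is full
                queue.put(end);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        producer.setDaemon(true);
        producer.start();
        return new Iterator<T>() {
            private Object next;                             // item taken but not yet returned

            @Override
            public boolean hasNext() {
                if (next == null) {
                    try {
                        next = queue.take();                 // waits for the producer if needed
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        return false;
                    }
                }
                return next != end;
            }

            @SuppressWarnings("unchecked")
            @Override
            public T next() {
                if (!hasNext()) throw new NoSuchElementException();
                T value = (T) next;
                next = null;
                return value;
            }
        };
    }

    public static void main(String[] args) {
        List<Integer> batches = List.of(1, 2, 3, 4, 5);
        for (int size : new int[] {0, 2}) {
            List<Integer> seen = new ArrayList<>();
            prefetch(batches.iterator(), size).forEachRemaining(seen::add);
            System.out.println("prefetchBuffer=" + size + " -> " + seen);
        }
    }
}
```

With size 2 the background thread can run at most two items ahead of the consumer, because put blocks once the queue is full; the order of items is preserved since there is a single producer and the queue is FIFO.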