public class ParameterServerTrainer extends DefaultTrainer
Using a ParameterServerClient, we maintain updates for training a neural net.
Training happens relative to the mode of the remote ParameterServerNode.
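The class description only says that updates flow through a ParameterServerClient to a remote ParameterServerNode. The following is a minimal sketch of the general pull/compute/push cycle behind that idea, assuming a single parameter vector and additive updates. `InMemoryParameterServer` and `SketchWorker` are hypothetical stand-ins, not ND4J or DL4J classes. The sketch ignores the node's "mode", since the source does not say what the modes are.

```java
import java.util.Arrays;

// Entry point first so the file can be run directly with the Java source launcher.
class ParameterServerSketch {
    public static void main(String[] args) {
        InMemoryParameterServer server = new InMemoryParameterServer(2);
        SketchWorker worker = new SketchWorker(server, 0.5);
        double[] target = {1.0, -2.0};
        for (int step = 0; step < 20; step++) {
            worker.step(target);
        }
        // Parameters approach the target; with learningRate 0.5 the error halves every step.
        System.out.println(Arrays.toString(server.pull()));
    }
}

/** Hypothetical in-memory stand-in for a remote parameter server node (not an ND4J/DL4J class). */
class InMemoryParameterServer {
    private final double[] params;

    InMemoryParameterServer(int size) {
        this.params = new double[size]; // global parameters start at zero
    }

    /** Returns a copy of the global parameters so callers cannot mutate server state. */
    synchronized double[] pull() {
        return Arrays.copyOf(params, params.length);
    }

    /** Applies a worker's additive update to the global parameters. */
    synchronized void push(double[] update) {
        if (update.length != params.length) {
            throw new IllegalArgumentException("update has wrong length: " + update.length);
        }
        for (int i = 0; i < params.length; i++) {
            params[i] += update[i];
        }
    }
}

/** Hypothetical worker: pull parameters, compute a local gradient step, push the update. */
class SketchWorker {
    private final InMemoryParameterServer server;
    private final double learningRate;

    SketchWorker(InMemoryParameterServer server, double learningRate) {
        this.server = server;
        this.learningRate = learningRate;
    }

    /** One step on the toy loss 0.5 * ||w - target||^2, whose gradient is (w - target). */
    void step(double[] target) {
        double[] w = server.pull();
        double[] update = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            update[i] = -learningRate * (w[i] - target[i]);
        }
        server.push(update);
    }
}
```

Pushing a delta rather than a full parameter vector lets several workers contribute to the same global state without overwriting each other, which is one common design for parameter servers.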
Modifier and Type | Class and Description |
---|---|
static class | ParameterServerTrainer.ParameterServerTrainerBuilder |
Nested classes inherited from class DefaultTrainer: DefaultTrainer.DefaultTrainerBuilder
Fields inherited from class DefaultTrainer: onRootModel, originalModel, parallelWrapper, queue, queueMDS, replicatedModel, running, shouldStop, shouldUpdate, threadId, thrownException, useMDS, uuid
Constructor and Description |
---|
ParameterServerTrainer() |
Modifier and Type | Method and Description |
---|---|
void | feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet): Train on a DataSet |
void | feedMultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet dataSet): Train on a MultiDataSet |
Model | getModel(): The current model for the trainer |
void | updateModel(Model model): Update the current Model for the worker |
Methods inherited from class DefaultTrainer: cloneListener, configureListeners, isRunning, run, setupIfNeccessary, shutdown, waitTillRunning
Methods inherited from class java.lang.Thread: activeCount, checkAccess, clone, countStackFrames, currentThread, destroy, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, stop, suspend, toString, yield
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Trainer: setUncaughtExceptionHandler, start
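Because DefaultTrainer is a Thread and exposes fields such as queue, running and shouldStop alongside run, shutdown and waitTillRunning, one plausible reading is a producer/consumer loop: callers feed batches, and the trainer thread drains them. The sketch below illustrates that pattern only; it is an assumption, not DL4J's implementation. `QueueTrainer` is hypothetical, an `int[]` stands in for a DataSet, and a `long` stands in for model state.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

class QueueTrainerDemo {
    public static void main(String[] args) throws InterruptedException {
        QueueTrainer trainer = new QueueTrainer();
        trainer.start();            // Thread.start() invokes run() on a new thread
        trainer.waitTillRunning();
        for (int i = 1; i <= 5; i++) {
            trainer.feedDataSet(new int[] {i, i});
        }
        trainer.shutdown();         // ask the loop to stop once the queue is drained
        trainer.join();
        System.out.println("batches=" + trainer.processed() + " state=" + trainer.modelState());
    }
}

/** Hypothetical queue-fed trainer thread; int[] stands in for a DataSet, a long for model state. */
class QueueTrainer extends Thread {
    private final BlockingQueue<int[]> queue = new LinkedBlockingQueue<>();
    private final AtomicBoolean running = new AtomicBoolean(false);
    private final AtomicBoolean shouldStop = new AtomicBoolean(false);
    private final AtomicInteger processed = new AtomicInteger();
    private volatile long modelState = 0; // written only by the trainer thread

    /** Producer side: callers enqueue batches and return immediately. */
    void feedDataSet(int[] batch) {
        queue.add(batch);
    }

    void shutdown() {
        shouldStop.set(true);
    }

    /** Spins until run() has started; a real implementation might use a latch instead. */
    void waitTillRunning() throws InterruptedException {
        while (!running.get()) {
            Thread.sleep(1);
        }
    }

    int processed() {
        return processed.get();
    }

    long modelState() {
        return modelState;
    }

    @Override
    public void run() {
        running.set(true);
        try {
            // Keep consuming until a stop was requested AND no batches remain.
            while (!shouldStop.get() || !queue.isEmpty()) {
                int[] batch = queue.poll(10, TimeUnit.MILLISECONDS);
                if (batch == null) {
                    continue; // timed out; re-check the stop condition
                }
                for (int x : batch) {
                    modelState += x; // stand-in for a fit() step on the batch
                }
                processed.incrementAndGet();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            running.set(false);
        }
    }
}
```

Polling with a timeout, instead of blocking indefinitely on take(), lets the loop notice a shutdown request even when no more batches arrive.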
public void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
Description copied from interface: Trainer
Train on a MultiDataSet
Specified by: feedMultiDataSet in interface Trainer
Overrides: feedMultiDataSet in class DefaultTrainer
Parameters: dataSet - the data set to train on

public void feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet)
Description copied from interface: Trainer
Train on a DataSet
Specified by: feedDataSet in interface Trainer
Overrides: feedDataSet in class DefaultTrainer
Parameters: dataSet - the data set to train on

public Model getModel()
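Description copied from interface: Trainer
The current model for the trainer
Specified by: getModel in interface Trainer
Overrides: getModel in class DefaultTrainer
Returns: the current Model for the worker

public void updateModel(@NonNull Model model)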
Description copied from interface: Trainer
Update the current Model for the worker
Specified by: updateModel in interface Trainer
Overrides: updateModel in class DefaultTrainer
Parameters: model - the new model for this worker
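getModel and updateModel together let a caller read or replace the model a running worker trains. Below is a minimal sketch of one way to make such a swap safe across threads, assuming an AtomicReference holds the current model. The @NonNull on updateModel is presumably Lombok's annotation, which generates a null check that throws NullPointerException; the sketch mirrors that with Objects.requireNonNull. `ToyModel` and `SwappableWorker` are hypothetical and are not the DL4J Model interface.

```java
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

class ModelSwapDemo {
    public static void main(String[] args) {
        SwappableWorker worker = new SwappableWorker(new ToyModel(new double[] {0.0, 0.0}));
        worker.train(new double[] {1.0, 1.0});
        System.out.println("after train: " + worker.getModel().params[0]);

        ToyModel fresh = new ToyModel(new double[] {5.0, 5.0});
        worker.updateModel(fresh);
        System.out.println("swapped: " + (worker.getModel() == fresh));
    }
}

/** Hypothetical stand-in for a DL4J Model: just a parameter vector. */
final class ToyModel {
    final double[] params;

    ToyModel(double[] params) {
        this.params = params.clone(); // defensive copy of the caller's array
    }
}

/** Hypothetical worker whose model can be replaced while it is running. */
class SwappableWorker {
    private final AtomicReference<ToyModel> model;

    SwappableWorker(ToyModel initial) {
        this.model = new AtomicReference<>(Objects.requireNonNull(initial, "initial"));
    }

    /** Counterpart of getModel(): the model the worker currently trains. */
    ToyModel getModel() {
        return model.get();
    }

    /** Counterpart of updateModel(@NonNull Model): rejects null, then publishes the new model. */
    void updateModel(ToyModel newModel) {
        model.set(Objects.requireNonNull(newModel, "model"));
    }

    /** Toy training step: move each parameter 10% of the way toward the batch value. */
    void train(double[] batch) {
        ToyModel m = model.get(); // read once so one step never mixes two models
        for (int i = 0; i < m.params.length; i++) {
            m.params[i] += 0.1 * (batch[i] - m.params[i]);
        }
    }
}
```

Reading the reference once per step means a concurrent updateModel takes effect at the next step rather than halfway through one; the null check runs before the set, so a rejected update leaves the current model in place.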