public class ParameterServerTrainer extends DefaultTrainer

Using a ParameterServerClient, this trainer maintains updates for training a neural net. Training happens relative to the mode of the remote ParameterServerNode.

| Modifier and Type | Class and Description |
|---|---|
| static class | ParameterServerTrainer.ParameterServerTrainerBuilder |

Nested classes inherited from class DefaultTrainer: DefaultTrainer.DefaultTrainerBuilder

Fields inherited from class DefaultTrainer: onRootModel, originalModel, parallelWrapper, queue, queueMDS, replicatedModel, running, shouldStop, shouldUpdate, threadId, thrownException, useMDS, uuid

| Constructor and Description |
|---|
| ParameterServerTrainer() |
| Modifier and Type | Method | Description |
|---|---|---|
| void | feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet) | Train on a DataSet |
| void | feedMultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet dataSet) | Train on a MultiDataSet |
| Model | getModel() | The current model for the trainer |
| void | updateModel(Model model) | Update the current Model for the worker |
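The class description says training is coordinated through a ParameterServerClient that talks to a remote ParameterServerNode. The sketch below illustrates the generic pull/compute/push cycle behind parameter-server training. It is self-contained Java and does not use DL4J or ND4J. All names are hypothetical (`InMemoryParameterServer`, `PsWorker`, `feed`, `currentParams`, `ParameterServerDemo`). The "server" is an in-process object, not a remote node. The model is a one-feature linear regression trained with plain SGD. `feed` loosely plays the role of feedDataSet and `currentParams` the role of getModel. This page does not document the exact order of operations inside ParameterServerTrainer, so treat this as an analogy rather than its implementation.

```java
import java.util.Arrays;

// Hypothetical in-process stand-in for a remote parameter server node.
// It owns the shared parameter vector; workers only ever receive copies.
class InMemoryParameterServer {
    private final double[] params;

    InMemoryParameterServer(double[] initial) {
        this.params = initial.clone();
    }

    // Pull: return a snapshot of the current shared parameters.
    synchronized double[] pull() {
        return params.clone();
    }

    // Push: apply a worker's additive update (delta) to the shared parameters.
    synchronized void push(double[] delta) {
        for (int i = 0; i < params.length; i++) {
            params[i] += delta[i];
        }
    }
}

// Hypothetical worker fitting y = w * x + b by mini-batch SGD,
// synchronizing with the server on every batch.
class PsWorker {
    private final InMemoryParameterServer server;
    private final double learningRate;

    PsWorker(InMemoryParameterServer server, double learningRate) {
        this.server = server;
        this.learningRate = learningRate;
    }

    // Loose analogue of feedDataSet: one pull/compute/push cycle for a batch of (x, y) pairs.
    void feed(double[] xs, double[] ys) {
        double[] p = server.pull(); // p[0] = w, p[1] = b
        double gw = 0.0;
        double gb = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double err = p[0] * xs[i] + p[1] - ys[i];
            gw += err * xs[i];
            gb += err;
        }
        // Gradient of half the mean squared error over this batch.
        gw /= xs.length;
        gb /= xs.length;
        server.push(new double[] {-learningRate * gw, -learningRate * gb});
    }

    // Loose analogue of getModel: the worker's model is whatever the server currently holds.
    double[] currentParams() {
        return server.pull();
    }
}

class ParameterServerDemo {
    public static void main(String[] args) {
        InMemoryParameterServer server = new InMemoryParameterServer(new double[] {0.0, 0.0});
        // Two workers share one server; each sees a different shard of data from y = 2x + 1.
        PsWorker a = new PsWorker(server, 0.1);
        PsWorker b = new PsWorker(server, 0.1);
        for (int round = 0; round < 2000; round++) {
            a.feed(new double[] {0, 1}, new double[] {1, 3});
            b.feed(new double[] {2, 3}, new double[] {5, 7});
        }
        // The shared parameters approach w = 2, b = 1.
        System.out.println(Arrays.toString(server.pull()));
    }
}
```

Workers push deltas rather than whole parameter vectors. This lets several workers update the same server without overwriting each other's progress. The cost is that, when workers run concurrently, a gradient may be computed from parameters that another worker has already changed.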
Methods inherited from class DefaultTrainer: cloneListener, configureListeners, isRunning, run, setupIfNeccessary, shutdown, waitTillRunning

Methods inherited from class java.lang.Thread: activeCount, checkAccess, clone, countStackFrames, currentThread, destroy, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, stop, suspend, toString, yield

Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface Trainer: setUncaughtExceptionHandler, start

public void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
Description copied from interface: Trainer

Train on a MultiDataSet.

Specified by: feedMultiDataSet in interface Trainer

Overrides: feedMultiDataSet in class DefaultTrainer

Parameters: dataSet - the data set to train on

public void feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet)
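Description copied from interface: Trainer

Train on a DataSet.

Specified by: feedDataSet in interface Trainer

Overrides: feedDataSet in class DefaultTrainer

Parameters: dataSet - the data set to train on

public Model getModel()

Description copied from interface: Trainer

The current model for the trainer.

Specified by: getModel in interface Trainer

Overrides: getModel in class DefaultTrainer

Returns: the Model for the worker

public void updateModel(@NonNull Model model)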
Description copied from interface: Trainer

Update the current Model for the worker.

Specified by: updateModel in interface Trainer

Overrides: updateModel in class DefaultTrainer

Parameters: model - the new model for this worker
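DefaultTrainer extends java.lang.Thread (note the inherited Thread methods). It also carries fields such as queue, queueMDS, running and shouldStop. A plausible reading, not confirmed by this page, is that feedDataSet and feedMultiDataSet hand batches to the trainer's own thread through a queue instead of training on the caller's thread. The sketch below shows that producer/consumer pattern in plain Java. `QueueingTrainer`, `feed`, `shutdownAndWait` and `processed` are hypothetical names. Summing the integers in a batch stands in for fitting a model.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical trainer thread: callers feed batches, this thread consumes them in FIFO order.
class QueueingTrainer extends Thread {
    private final BlockingQueue<int[]> queue = new LinkedBlockingQueue<>();
    private final AtomicBoolean shouldStop = new AtomicBoolean(false);
    private final List<Integer> processed = Collections.synchronizedList(new ArrayList<>());

    // Loose analogue of feedDataSet: enqueue the batch and return immediately.
    void feed(int[] batch) {
        queue.add(batch);
    }

    @Override
    public void run() {
        try {
            // Keep working until a stop is requested AND every queued batch has been consumed.
            while (!shouldStop.get() || !queue.isEmpty()) {
                int[] batch = queue.poll(10, TimeUnit.MILLISECONDS);
                if (batch == null) {
                    continue; // nothing arrived yet; re-check the stop flag
                }
                int sum = 0;
                for (int v : batch) {
                    sum += v; // placeholder for fitting the model on this batch
                }
                processed.add(sum);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status and exit
        }
    }

    // Loose analogue of shutdown: request a stop, then wait for the queue to drain and the thread to exit.
    void shutdownAndWait() {
        shouldStop.set(true);
        try {
            join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Per-batch results, in the order the batches were consumed.
    List<Integer> processed() {
        return processed;
    }
}
```

Because `feed` only enqueues, the caller is never blocked by training. The trade-off is that results arrive asynchronously, so callers need lifecycle hooks to know when work has finished. That is presumably the role of methods such as isRunning, shutdown and waitTillRunning inherited from DefaultTrainer.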