public class DefaultTrainer extends java.lang.Thread implements Trainer
| Modifier and Type | Class and Description |
|---|---|
static class |
DefaultTrainer.DefaultTrainerBuilder |
| Modifier and Type | Field and Description |
|---|---|
protected boolean |
onRootModel |
protected Model |
originalModel |
protected ParallelWrapper |
parallelWrapper |
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> |
queue |
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> |
queueMDS |
protected Model |
replicatedModel |
protected java.util.concurrent.atomic.AtomicInteger |
running |
protected java.util.concurrent.atomic.AtomicBoolean |
shouldStop |
protected java.util.concurrent.atomic.AtomicBoolean |
shouldUpdate |
protected int |
threadId |
protected java.lang.Exception |
thrownException |
protected boolean |
useMDS |
protected java.lang.String |
uuid |
| Constructor and Description |
|---|
DefaultTrainer() |
| Modifier and Type | Method and Description |
|---|---|
protected static IterationListener |
cloneListener(IterationListener original) |
protected void |
configureListeners(java.lang.String workerUUID,
java.util.Collection<IterationListener> oldListeners,
java.util.Collection<IterationListener> replicatedListeners) |
void |
feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet)
Train on a
DataSet |
void |
feedMultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
Train on a
MultiDataSet |
Model |
getModel()
The current model for the trainer.
|
boolean |
isRunning() |
void |
run() |
protected void |
setupIfNeccessary() |
void |
shutdown()
Shut down this worker.
|
void |
updateModel(Model model)
Update the current
Model
for the worker |
void |
waitTillRunning()
Block the main thread
till the trainer is up and running.
|
Methods inherited from class java.lang.Thread: activeCount, checkAccess, clone, countStackFrames, currentThread, destroy, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, stop, suspend, toString, yield
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Trainer: setUncaughtExceptionHandler, start
protected Model originalModel
protected Model replicatedModel
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queueMDS
protected java.util.concurrent.atomic.AtomicInteger running
protected int threadId
protected java.util.concurrent.atomic.AtomicBoolean shouldUpdate
protected java.util.concurrent.atomic.AtomicBoolean shouldStop
protected java.lang.Exception thrownException
protected volatile boolean useMDS
protected final java.lang.String uuid
protected boolean onRootModel
protected ParallelWrapper parallelWrapper
public void feedMultiDataSet(@NonNull
org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
Description copied from interface: Trainer
Train on a MultiDataSet.
Specified by: feedMultiDataSet in interface Trainer
Parameters: dataSet - the data set to train on
public void feedDataSet(@NonNull
org.nd4j.linalg.dataset.api.DataSet dataSet)
Description copied from interface: Trainer
Train on a DataSet.
Specified by: feedDataSet in interface Trainer
Parameters: dataSet - the data set to train on
public Model getModel()
Description copied from interface: Trainer
The current model for the trainer.
public void updateModel(@NonNull
Model model)
Description copied from interface: Trainer
Update the current Model for the worker.
Specified by: updateModel in interface Trainer
Parameters: model - the new model for this worker
protected void setupIfNeccessary()
public void shutdown()
Description copied from interface: Trainer
Shut down this worker.
public void run()
Specified by: run in interface java.lang.Runnable
Overrides: run in class java.lang.Thread
public void waitTillRunning()
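Description copied from interface: Trainer
Block the main thread until the trainer is up and running.
Specified by: waitTillRunning in interface Trainer
protected static IterationListener cloneListener(IterationListener original)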
protected void configureListeners(java.lang.String workerUUID,
java.util.Collection<IterationListener> oldListeners,
java.util.Collection<IterationListener> replicatedListeners)