public class DefaultTrainer extends java.lang.Thread implements Trainer
Modifier and Type | Class and Description |
---|---|
static class |
DefaultTrainer.DefaultTrainerBuilder |
Modifier and Type | Field and Description |
---|---|
protected boolean |
onRootModel |
protected Model |
originalModel |
protected ParallelWrapper |
parallelWrapper |
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> |
queue |
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> |
queueMDS |
protected Model |
replicatedModel |
protected java.util.concurrent.atomic.AtomicInteger |
running |
protected java.util.concurrent.atomic.AtomicBoolean |
shouldStop |
protected java.util.concurrent.atomic.AtomicBoolean |
shouldUpdate |
protected int |
threadId |
protected java.lang.Exception |
thrownException |
protected boolean |
useMDS |
protected java.lang.String |
uuid |
Constructor and Description |
---|
DefaultTrainer() |
Modifier and Type | Method and Description |
---|---|
protected static IterationListener |
cloneListener(IterationListener original) |
protected void |
configureListeners(java.lang.String workerUUID,
java.util.Collection<IterationListener> oldListeners,
java.util.Collection<IterationListener> replicatedListeners) |
void |
feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet)
Train on a
DataSet |
void |
feedMultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
Train on a
MultiDataSet |
Model |
getModel()
The current model for the trainer
|
boolean |
isRunning() |
void |
run() |
protected void |
setupIfNeccessary() |
void |
shutdown()
Shut down this worker
|
void |
updateModel(Model model)
Update the current
Model
for the worker |
void |
waitTillRunning()
Block the main thread
until the trainer is up and running.
|
activeCount, checkAccess, clone, countStackFrames, currentThread, destroy, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, stop, suspend, toString, yield
equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
setUncaughtExceptionHandler, start
protected Model originalModel
protected Model replicatedModel
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue
protected java.util.concurrent.LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queueMDS
protected java.util.concurrent.atomic.AtomicInteger running
protected int threadId
protected java.util.concurrent.atomic.AtomicBoolean shouldUpdate
protected java.util.concurrent.atomic.AtomicBoolean shouldStop
protected java.lang.Exception thrownException
protected volatile boolean useMDS
protected final java.lang.String uuid
protected boolean onRootModel
protected ParallelWrapper parallelWrapper
public void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
Trainer
MultiDataSet
feedMultiDataSet
in interface Trainer
dataSet
- the data set to train on
public void feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet)
Trainer
DataSet
feedDataSet
in interface Trainer
dataSet
- the data set to train on
public Model getModel()
Trainer
public void updateModel(@NonNull Model model)
Trainer
Model
for the worker
updateModel
in interface Trainer
model
- the new model for this worker
protected void setupIfNeccessary()
public void shutdown()
Trainer
public void run()
run
in interface java.lang.Runnable
run
in class java.lang.Thread
public void waitTillRunning()
Trainer
waitTillRunning
in interface Trainer
protected static IterationListener cloneListener(IterationListener original)
protected void configureListeners(java.lang.String workerUUID, java.util.Collection<IterationListener> oldListeners, java.util.Collection<IterationListener> replicatedListeners)