public abstract class BaseStatsListener extends java.lang.Object implements RoutingIterationListener
Modifier and Type | Field and Description |
---|---|
static java.lang.String | TYPE_ID |
Constructor and Description |
---|
BaseStatsListener(StatsStorageRouter router): Create a stats listener with network information collected at every iteration. |
BaseStatsListener(StatsStorageRouter router, int listenerFrequency): Create a stats listener with network information collected every n >= 1 iterations. |
BaseStatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig, StatsUpdateConfiguration updateConfig, java.lang.String sessionID, java.lang.String workerID) |
Modifier and Type | Method and Description |
---|---|
abstract BaseStatsListener | clone() |
StatsInitializationConfiguration | getInitConfig() |
abstract StatsInitializationReport | getNewInitializationReport() |
abstract StatsReport | getNewStatsReport() |
abstract StorageMetaData | getNewStorageMetaData(long initTime, java.lang.String sessionID, java.lang.String workerID) |
java.lang.String | getSessionID() |
StatsStorageRouter | getStorageRouter() |
StatsUpdateConfiguration | getUpdateConfig() |
java.lang.String | getWorkerID() |
void | invoke(): Set the invoked flag to true. |
boolean | invoked(): Get whether the listener has been invoked. |
void | iterationDone(Model model, int iteration): Event listener for each iteration. |
void | onBackwardPass(Model model): Called once per iteration (backward pass) after gradients have been calculated and updated. Gradients are available via Model.gradient(). |
void | onEpochEnd(Model model): Called once at the end of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator). |
void | onEpochStart(Model model): Called once at the start of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator). |
void | onForwardPass(Model model, java.util.List<org.nd4j.linalg.api.ndarray.INDArray> activations): Called once per iteration (forward pass) for activations (usually for a MultiLayerNetwork), only at training time. |
void | onForwardPass(Model model, java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> activations): Called once per iteration (forward pass) for activations (usually for a ComputationGraph), only at training time. |
void | onGradientCalculation(Model model): Called once per iteration (backward pass) before the gradients are updated. Gradients are available via Model.gradient(). |
void | setSessionID(java.lang.String sessionID) |
void | setStorageRouter(StatsStorageRouter router) |
void | setUpdateConfig(StatsUpdateConfiguration newConfig) |
void | setWorkerID(java.lang.String workerID) |
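The method summary above names the TrainingListener callbacks that BaseStatsListener implements. As a rough mental model of how they fit together during training, the sketch below drives a hypothetical, simplified listener interface (`SketchTrainingListener`, which takes plain `int` counters instead of a `Model`) through a toy fit loop and records the order of the callbacks. The per-iteration ordering (forward pass, gradient calculation, backward pass, then iterationDone) is an assumption based on the method descriptions, not a guarantee about the actual DL4J fit implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified stand-in for TrainingListener: the real interface
// receives a Model (and activations) rather than plain int counters.
interface SketchTrainingListener {
    void onEpochStart(int epoch);
    void onForwardPass(int iteration);
    void onGradientCalculation(int iteration);
    void onBackwardPass(int iteration);
    void iterationDone(int iteration);
    void onEpochEnd(int epoch);
}

// Records every callback so that the ordering can be inspected afterwards.
class RecordingListener implements SketchTrainingListener {
    final List<String> events = new ArrayList<>();
    public void onEpochStart(int epoch) { events.add("epochStart" + epoch); }
    public void onForwardPass(int it) { events.add("forward" + it); }
    public void onGradientCalculation(int it) { events.add("gradient" + it); }
    public void onBackwardPass(int it) { events.add("backward" + it); }
    public void iterationDone(int it) { events.add("done" + it); }
    public void onEpochEnd(int epoch) { events.add("epochEnd" + epoch); }
}

class CallbackOrderDemo {
    // Hypothetical fit loop: each epoch runs itersPerEpoch iterations (minibatches),
    // and the iteration counter keeps increasing across epochs.
    static void fit(SketchTrainingListener l, int epochs, int itersPerEpoch) {
        int iteration = 0;
        for (int e = 0; e < epochs; e++) {
            l.onEpochStart(e);
            for (int b = 0; b < itersPerEpoch; b++) {
                l.onForwardPass(iteration);         // activations available (training only)
                l.onGradientCalculation(iteration); // raw gradients, before the updater runs
                l.onBackwardPass(iteration);        // gradients after the update step
                l.iterationDone(iteration);
                iteration++;
            }
            l.onEpochEnd(e);
        }
    }

    public static void main(String[] args) {
        RecordingListener listener = new RecordingListener();
        fit(listener, 1, 2);
        System.out.println(listener.events);
    }
}
```

A real listener such as BaseStatsListener would do its work inside these callbacks; the recording listener only makes the call sequence visible.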
public static final java.lang.String TYPE_ID
public BaseStatsListener(StatsStorageRouter router)
Equivalent to BaseStatsListener(StatsStorageRouter, int) with listenerFrequency == 1.
Parameters:
router - Where/how to store the calculated stats. For example, InMemoryStatsStorage or FileStatsStorage
public BaseStatsListener(StatsStorageRouter router, int listenerFrequency)
Parameters:
router - Where/how to store the calculated stats. For example, InMemoryStatsStorage or FileStatsStorage
listenerFrequency - Frequency with which to collect stats information
public BaseStatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig, StatsUpdateConfiguration updateConfig, java.lang.String sessionID, java.lang.String workerID)
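The `listenerFrequency` constructor argument sets how often stats are collected. A minimal sketch of one plausible interpretation, assuming stats are gathered whenever the (0-based) iteration count is a multiple of the frequency; `FrequencyGate` is a hypothetical helper, not part of DL4J:

```java
// Hypothetical sketch of listenerFrequency semantics: collect stats on every
// n-th iteration, where n >= 1 (n == 1 means "every iteration").
class FrequencyGate {
    private final int listenerFrequency;

    FrequencyGate(int listenerFrequency) {
        if (listenerFrequency < 1) {
            throw new IllegalArgumentException("listenerFrequency must be >= 1, got " + listenerFrequency);
        }
        this.listenerFrequency = listenerFrequency;
    }

    // Returns true when stats should be collected for this 0-based iteration.
    boolean shouldCollect(int iteration) {
        return iteration % listenerFrequency == 0;
    }

    public static void main(String[] args) {
        FrequencyGate everyThird = new FrequencyGate(3);
        for (int i = 0; i < 7; i++) {
            System.out.println(i + " -> " + everyThird.shouldCollect(i));
        }
    }
}
```

Rejecting values below 1 in the constructor mirrors the "n >= 1" requirement stated in the constructor summary.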
public abstract StatsInitializationReport getNewInitializationReport()
public abstract StatsReport getNewStatsReport()
public abstract StorageMetaData getNewStorageMetaData(long initTime, java.lang.String sessionID, java.lang.String workerID)
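The three abstract `getNew*` methods suggest a template-method design: the base class drives collection and storage, while each concrete subclass decides which report implementation to instantiate. The sketch below illustrates that pattern with hypothetical, heavily simplified types (`SketchReport`, `SketchBaseListener`, `SketchStatsListener`); it is an assumption about the design, not the actual BaseStatsListener code.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, heavily simplified stand-in for a stats report; the real
// StatsReport type carries far more information.
class SketchReport {
    final String sessionID;
    final int iteration;

    SketchReport(String sessionID, int iteration) {
        this.sessionID = sessionID;
        this.iteration = iteration;
    }
}

// The base class owns the per-iteration logic and defers creation of the
// concrete report object to subclasses (template-method pattern).
abstract class SketchBaseListener {
    final List<SketchReport> emitted = new ArrayList<>();

    // Similar in spirit to getNewStatsReport(): the subclass picks the report type.
    abstract SketchReport newStatsReport(int iteration);

    void iterationDone(int iteration) {
        SketchReport report = newStatsReport(iteration); // created by the subclass
        emitted.add(report);                             // stored by the base class
    }
}

class SketchStatsListener extends SketchBaseListener {
    @Override
    SketchReport newStatsReport(int iteration) {
        return new SketchReport("session-1", iteration);
    }

    public static void main(String[] args) {
        SketchStatsListener listener = new SketchStatsListener();
        listener.iterationDone(0);
        listener.iterationDone(1);
        System.out.println(listener.emitted.size());
    }
}
```

The benefit of this split is that the collection and routing code is written once, while the serialization format of the reports can vary by subclass.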
public StatsInitializationConfiguration getInitConfig()
public StatsUpdateConfiguration getUpdateConfig()
public void setUpdateConfig(StatsUpdateConfiguration newConfig)
public void setStorageRouter(StatsStorageRouter router)
Specified by: setStorageRouter in interface RoutingIterationListener
public StatsStorageRouter getStorageRouter()
Specified by: getStorageRouter in interface RoutingIterationListener
public void setWorkerID(java.lang.String workerID)
Specified by: setWorkerID in interface RoutingIterationListener
public java.lang.String getWorkerID()
Specified by: getWorkerID in interface RoutingIterationListener
public void setSessionID(java.lang.String sessionID)
Specified by: setSessionID in interface RoutingIterationListener
public java.lang.String getSessionID()
Specified by: getSessionID in interface RoutingIterationListener
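The router, session ID and worker ID can all be replaced after construction through the RoutingIterationListener accessors. The sketch below shows why the IDs matter when several workers report into one store: reports are grouped by (sessionID, workerID). `SketchRouter` and `SketchRoutingListener` are hypothetical stand-ins, not the real StatsStorageRouter API.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory router that groups reports by (sessionID, workerID).
// This is not the real DL4J StatsStorageRouter interface.
class SketchRouter {
    private final Map<String, List<String>> bySessionAndWorker = new LinkedHashMap<>();

    void putUpdate(String sessionID, String workerID, String report) {
        bySessionAndWorker.computeIfAbsent(sessionID + "/" + workerID, k -> new ArrayList<>()).add(report);
    }

    List<String> updates(String sessionID, String workerID) {
        return bySessionAndWorker.getOrDefault(sessionID + "/" + workerID, new ArrayList<>());
    }
}

// Mirrors the accessor shape of RoutingIterationListener: the router and IDs
// are mutable, and every report is tagged with the current IDs.
class SketchRoutingListener {
    private SketchRouter router;
    private String sessionID;
    private String workerID;

    SketchRoutingListener(SketchRouter router, String sessionID, String workerID) {
        this.router = router;
        this.sessionID = sessionID;
        this.workerID = workerID;
    }

    void setStorageRouter(SketchRouter router) { this.router = router; }
    void setSessionID(String sessionID) { this.sessionID = sessionID; }
    void setWorkerID(String workerID) { this.workerID = workerID; }

    void report(String stats) { router.putUpdate(sessionID, workerID, stats); }

    public static void main(String[] args) {
        SketchRouter router = new SketchRouter();
        SketchRoutingListener listener = new SketchRoutingListener(router, "session-1", "worker-0");
        listener.report("iteration 0 stats");
        listener.setWorkerID("worker-1");
        listener.report("iteration 0 stats");
        System.out.println(router.updates("session-1", "worker-0").size());
    }
}
```

Because the IDs are read at report time, changing the worker ID redirects subsequent reports without touching earlier ones.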
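public boolean invoked()
Description copied from interface: IterationListener
Get whether the listener has been invoked.
Specified by: invoked in interface IterationListener
public void invoke()
Description copied from interface: IterationListener
Set the invoked flag to true.
Specified by: invoke in interface IterationListener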
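invoke() and invoked() together describe a simple one-way flag. A minimal sketch of that contract, using a hypothetical `InvokedFlag` class; when the real listener calls invoke() is not specified on this page and is not modeled here:

```java
// Hypothetical sketch of the invoke()/invoked() contract: a flag that starts
// false and, once set, stays true.
class InvokedFlag {
    private boolean invoked = false;

    // Set the invoked flag to true (calling it again has no further effect).
    void invoke() {
        invoked = true;
    }

    // Get whether invoke() has been called at least once.
    boolean invoked() {
        return invoked;
    }

    public static void main(String[] args) {
        InvokedFlag flag = new InvokedFlag();
        System.out.println(flag.invoked());
        flag.invoke();
        System.out.println(flag.invoked());
    }
}
```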
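public void onEpochStart(Model model)
Description copied from interface: TrainingListener
Called once at the start of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator).
Specified by: onEpochStart in interface TrainingListener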
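public void onEpochEnd(Model model)
Description copied from interface: TrainingListener
Called once at the end of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator).
Specified by: onEpochEnd in interface TrainingListener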
public void onForwardPass(Model model, java.util.List<org.nd4j.linalg.api.ndarray.INDArray> activations)
Description copied from interface: TrainingListener
Called once per iteration (forward pass) for activations (usually for a MultiLayerNetwork), only at training time.
Specified by: onForwardPass in interface TrainingListener
Parameters:
model - Model
activations - Layer activations (including input)
public void onForwardPass(Model model, java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> activations)
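Description copied from interface: TrainingListener
Called once per iteration (forward pass) for activations (usually for a ComputationGraph), only at training time.
Specified by: onForwardPass in interface TrainingListener
Parameters:
model - Model
activations - Layer activations (including input)
public void onGradientCalculation(Model model)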
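Description copied from interface: TrainingListener
Called once per iteration (backward pass) before the gradients are updated. Gradients are available via Model.gradient().
Note that gradients will likely be updated in place, so they should be copied or processed synchronously in this method.
For updates (gradients after learning rate, momentum, RMSProp, etc. have been applied), see TrainingListener.onBackwardPass(Model).
Specified by: onGradientCalculation in interface TrainingListener
Parameters:
model - Model
public void onBackwardPass(Model model)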
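Description copied from interface: TrainingListener
Called once per iteration (backward pass) after gradients have been calculated and updated. Gradients are available via Model.gradient().
Unlike TrainingListener.onGradientCalculation(Model), the gradients at this point are post-update, rather than the raw (pre-update) gradients seen in that method.
Specified by: onBackwardPass in interface TrainingListener
Parameters:
model - Model
public void iterationDone(Model model, int iteration)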
Description copied from interface: IterationListener
Event listener for each iteration.
Specified by: iterationDone in interface IterationListener
Parameters:
model - the model iterating
iteration - the iteration
public abstract BaseStatsListener clone()
Specified by: clone in interface RoutingIterationListener
Overrides: clone in class java.lang.Object
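The note under onGradientCalculation(Model) warns that gradients are likely updated in place. The sketch below shows what that means for a listener: a stored reference ends up holding post-update values, while a copy taken during the callback keeps the raw gradient. It uses a plain `double[]` and a hypothetical `applyUpdater` step in place of a real Model and INDArray; with INDArray gradients, the analogous defensive copy would presumably be `dup()`.

```java
import java.util.Arrays;

// Hypothetical sketch: a "model" whose gradient array is overwritten in place
// by the updater, mimicking the note on onGradientCalculation(Model).
class InPlaceGradientDemo {
    final double[] gradient = {0.5, -1.0};

    // Stand-in updater: scales the raw gradient by a learning rate, in place.
    void applyUpdater(double learningRate) {
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] *= learningRate;
        }
    }

    public static void main(String[] args) {
        InPlaceGradientDemo model = new InPlaceGradientDemo();

        // Inside onGradientCalculation: keep a reference and a defensive copy.
        double[] reference = model.gradient;
        double[] copy = model.gradient.clone();

        model.applyUpdater(0.5); // what happens between the two callbacks

        // The reference now shows post-update values; only the copy is still raw.
        System.out.println(Arrays.toString(reference));
        System.out.println(Arrays.toString(copy));
    }
}
```

Copying costs memory proportional to the number of parameters, which is why the documentation also allows processing the gradients synchronously inside the callback instead.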