public interface RecurrentLayer extends Layer
Nested classes/interfaces inherited from interface Layer: Layer.TrainingMode, Layer.Type
Modifier and Type | Method and Description |
---|---|
org.nd4j.linalg.api.ndarray.INDArray | rnnActivateUsingStoredState(org.nd4j.linalg.api.ndarray.INDArray input, boolean training, boolean storeLastForTBPTT) Similar to rnnTimeStep, this method is used for activations using the state stored in the stateMap as the initialization. |
void | rnnClearPreviousState() Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState(). |
java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState() Returns a shallow copy of the RNN stateMap (which contains the stored history for use in methods such as rnnTimeStep). |
java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetTBPTTState() Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
void | rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap) Set the stateMap (stored history). |
void | rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state) Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
org.nd4j.linalg.api.ndarray.INDArray | rnnTimeStep(org.nd4j.linalg.api.ndarray.INDArray input) Do one or more time steps using the previous time step state stored in stateMap. Can be used to efficiently do a forward pass one or n steps at a time (instead of always doing the forward pass from t=0). If stateMap is empty, default initialization (usually zeros) is used. Implementations also update stateMap at the end of this method. |
Pair<Gradient,org.nd4j.linalg.api.ndarray.INDArray> | tbpttBackpropGradient(org.nd4j.linalg.api.ndarray.INDArray epsilon, int tbpttBackLength) Truncated BPTT equivalent of Layer.backpropGradient(). |
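The rnnTimeStep() contract above (start from the stored state, or zeros if stateMap is empty, then update stateMap) can be illustrated without ND4J. The following is a minimal sketch, not the DL4J implementation: ToyTimeStepLayer, its single hidden unit h_t = tanh(w * x_t + u * h_{t-1}), the "prevActivation" key, and the use of double[] (one value per time step) instead of INDArray are all hypothetical simplifications.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical toy stand-in for a recurrent layer: one hidden unit with
 * h_t = tanh(w * x_t + u * h_{t-1}). Inputs and outputs are plain double[]
 * arrays indexed by time step instead of INDArrays.
 */
class ToyTimeStepLayer {
    static final String PREV_ACTIVATION = "prevActivation"; // hypothetical state key
    private final double w;
    private final double u;
    private final Map<String, double[]> stateMap = new HashMap<>();

    ToyTimeStepLayer(double w, double u) {
        this.w = w;
        this.u = u;
    }

    /** Runs the recurrence over every step of input, starting from hPrev. */
    private double[] forward(double[] input, double hPrev) {
        double[] out = new double[input.length];
        double h = hPrev;
        for (int t = 0; t < input.length; t++) {
            h = Math.tanh(w * input[t] + u * h);
            out[t] = h;
        }
        return out;
    }

    /** Full forward pass from t = 0 with zero initial state; leaves stateMap alone. */
    double[] activate(double[] input) {
        return forward(input, 0.0);
    }

    /** Starts from the stored state (zero if none), then stores the last activation. */
    double[] rnnTimeStep(double[] input) {
        double[] stored = stateMap.get(PREV_ACTIVATION);
        double hPrev = (stored == null) ? 0.0 : stored[0];
        double[] out = forward(input, hPrev);
        if (out.length > 0) {
            stateMap.put(PREV_ACTIVATION, new double[] {out[out.length - 1]});
        }
        return out;
    }

    /** Drops the stored history so the next rnnTimeStep call starts from zeros again. */
    void rnnClearPreviousState() {
        stateMap.clear();
    }
}
```

In this toy, feeding {1.0, -0.5} and then {0.25} through rnnTimeStep yields the same three activations as activate on {1.0, -0.5, 0.25}, without recomputing the first two steps; calling rnnClearPreviousState() between the two calls would instead restart the second call from zeros.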
Methods inherited from interface Layer: activate, activate, activate, activate, activate, activate, activationMean, backpropGradient, calcGradient, calcL1, calcL2, clone, derivativeActivation, error, feedForwardMaskArray, getIndex, getInputMiniBatchSize, getListeners, getMaskArray, isPretrainLayer, merge, preOutput, preOutput, preOutput, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, transpose, type
Methods inherited from interface Model: accumulateScore, applyLearningRateScoreDecay, batchSize, clear, computeGradientAndScore, conf, fit, fit, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update, validateInput
org.nd4j.linalg.api.ndarray.INDArray rnnTimeStep(org.nd4j.linalg.api.ndarray.INDArray input)
Parameters:
input - Input to this layer
java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()
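rnnGetPreviousState() is documented as returning a shallow copy of stateMap. The sketch below, using plain Java collections rather than ND4J, shows what that implies: replacing an entry in the returned map leaves the layer untouched, but editing a returned array in place is visible to the layer. ShallowStateHolder, its "prevActivation" key, its values, and the deepCopy helper are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical holder showing what a shallow copy of a state map does and does not isolate. */
class ShallowStateHolder {
    private final Map<String, double[]> stateMap = new HashMap<>();

    ShallowStateHolder() {
        stateMap.put("prevActivation", new double[] {0.1, 0.2}); // hypothetical key and values
    }

    /** Shallow copy: a new Map whose values are the same array objects the holder keeps. */
    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    /** Direct view of an internal array, used here only to observe the effect of edits. */
    double[] internal(String key) {
        return stateMap.get(key);
    }

    /** Copies the map and every array, giving a snapshot independent of the holder. */
    static Map<String, double[]> deepCopy(Map<String, double[]> state) {
        Map<String, double[]> copy = new HashMap<>();
        for (Map.Entry<String, double[]> e : state.entrySet()) {
            copy.put(e.getKey(), e.getValue().clone());
        }
        return copy;
    }
}
```

With INDArray values the same reasoning suggests that an in-place operation on an array obtained from rnnGetPreviousState() may alter the layer's stored state, so a caller wanting an independent snapshot would presumably duplicate each array (for example with INDArray.dup()).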
void rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
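Together, rnnGetPreviousState() and rnnSetPreviousState() allow the stored history to be saved and restored, for instance to run two different continuations from the same prefix. This is a minimal sketch under stated assumptions: BranchingRnn, its one-input-at-a-time rnnTimeStep(double), the "prevActivation" key and the deepCopy helper are hypothetical, and double[] stands in for INDArray.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical single-unit layer, h_t = tanh(w * x_t + u * h_{t-1}), stepped one input at a time. */
class BranchingRnn {
    private static final String KEY = "prevActivation"; // hypothetical state key
    private final double w;
    private final double u;
    private Map<String, double[]> stateMap = new HashMap<>();

    BranchingRnn(double w, double u) {
        this.w = w;
        this.u = u;
    }

    /** One time step from the stored state (zero if absent); the new activation is stored. */
    double rnnTimeStep(double x) {
        double[] stored = stateMap.get(KEY);
        double h = Math.tanh(w * x + u * (stored == null ? 0.0 : stored[0]));
        stateMap.put(KEY, new double[] {h});
        return h;
    }

    /** Shallow copy of the stored history. */
    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    /** Replaces the stored history wholesale with the given map's entries. */
    void rnnSetPreviousState(Map<String, double[]> newState) {
        stateMap = new HashMap<>(newState);
    }

    /** Snapshot helper: copies the arrays too, so later in-place edits cannot reach it. */
    static Map<String, double[]> deepCopy(Map<String, double[]> state) {
        Map<String, double[]> copy = new HashMap<>();
        for (Map.Entry<String, double[]> e : state.entrySet()) {
            copy.put(e.getKey(), e.getValue().clone());
        }
        return copy;
    }
}
```

Usage: step through a prefix, take `saved = BranchingRnn.deepCopy(layer.rnnGetPreviousState())`, explore one continuation, then call `layer.rnnSetPreviousState(saved)` and the next step behaves exactly as if the first continuation had never run.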
void rnnClearPreviousState()
org.nd4j.linalg.api.ndarray.INDArray rnnActivateUsingStoredState(org.nd4j.linalg.api.ndarray.INDArray input, boolean training, boolean storeLastForTBPTT)
Parameters:
input - Layer input
training - If true, training mode; otherwise, test mode
storeLastForTBPTT - If true, store the final state in tBpttStateMap for use in truncated BPTT training
java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()
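With storeLastForTBPTT set to true, the final state of each call is kept in tBpttStateMap, which lets a training loop process a long sequence in segments while carrying the state forward. The sketch below assumes, as the rnnClearPreviousState() description suggests, that rnnActivateUsingStoredState() reads its initial state from tBpttStateMap (not stateMap); ToyTbpttForwardLayer, the "prevActivation" key and the segmentedForward helper are hypothetical, and double[] stands in for INDArray.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical toy layer keeping separate state maps for rnnTimeStep and for TBPTT training. */
class ToyTbpttForwardLayer {
    static final String KEY = "prevActivation"; // hypothetical state key
    private final double w;
    private final double u;
    private final Map<String, double[]> stateMap = new HashMap<>();
    private final Map<String, double[]> tBpttStateMap = new HashMap<>();

    ToyTbpttForwardLayer(double w, double u) {
        this.w = w;
        this.u = u;
    }

    /**
     * Assumption: starts from tBpttStateMap (zero if empty). The training flag is
     * ignored because this toy has no train/test-dependent behaviour.
     */
    double[] rnnActivateUsingStoredState(double[] input, boolean training, boolean storeLastForTBPTT) {
        double[] stored = tBpttStateMap.get(KEY);
        double h = (stored == null) ? 0.0 : stored[0];
        double[] out = new double[input.length];
        for (int t = 0; t < input.length; t++) {
            h = Math.tanh(w * input[t] + u * h);
            out[t] = h;
        }
        if (storeLastForTBPTT && input.length > 0) {
            tBpttStateMap.put(KEY, new double[] {h});
        }
        return out;
    }

    Map<String, double[]> rnnGetTBPTTState() {
        return new HashMap<>(tBpttStateMap);
    }

    void rnnSetTBPTTState(Map<String, double[]> state) {
        tBpttStateMap.clear();
        tBpttStateMap.putAll(state);
    }

    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    /** Clears both maps, matching the documented behaviour of rnnClearPreviousState(). */
    void rnnClearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }

    /** Forward pass over a long sequence in segments of tbpttLength, carrying state between them. */
    double[] segmentedForward(double[] sequence, int tbpttLength) {
        double[] all = new double[sequence.length];
        for (int start = 0; start < sequence.length; start += tbpttLength) {
            int end = Math.min(start + tbpttLength, sequence.length);
            double[] out = rnnActivateUsingStoredState(Arrays.copyOfRange(sequence, start, end), true, true);
            System.arraycopy(out, 0, all, start, out.length);
        }
        return all;
    }
}
```

In this toy the segmented pass reproduces a single pass over the whole sequence, and stateMap (used by rnnTimeStep) is never touched, which is why the two maps can coexist without interfering.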
void rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state)
Parameters:
state - TBPTT state to set
Pair<Gradient,org.nd4j.linalg.api.ndarray.INDArray> tbpttBackpropGradient(org.nd4j.linalg.api.ndarray.INDArray epsilon, int tbpttBackLength)
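tbpttBackpropGradient() is described only as the truncated BPTT equivalent of Layer.backpropGradient(). One common formulation of truncation is sketched below: the backward pass visits only the last tbpttBackLength time steps, so gradient contributions from earlier steps are dropped. This is an assumption about the technique, not the DL4J code. ToyTruncatedBackprop, its linear recurrence h_t = w * x_t + u * h_{t-1} (with h_{-1} = h0), the Result class, and passing x and h0 explicitly (the real method receives only epsilon and tbpttBackLength) are all hypothetical.

```java
/** Hypothetical toy: truncated BPTT for the linear recurrence h_t = w * x_t + u * h_{t-1}. */
class ToyTruncatedBackprop {
    /** Loss gradients for the two weights, plus dL/dx_t to pass to the layer below. */
    static final class Result {
        final double gradW;
        final double gradU;
        final double[] epsilonOut;

        Result(double gradW, double gradU, double[] epsilonOut) {
            this.gradW = gradW;
            this.gradU = gradU;
            this.epsilonOut = epsilonOut;
        }
    }

    final double w;
    final double u;

    ToyTruncatedBackprop(double w, double u) {
        this.w = w;
        this.u = u;
    }

    /** Forward pass starting from the previous state h0. */
    double[] forward(double[] x, double h0) {
        double[] h = new double[x.length];
        double prev = h0;
        for (int t = 0; t < x.length; t++) {
            prev = w * x[t] + u * prev;
            h[t] = prev;
        }
        return h;
    }

    /**
     * epsilon[t] is dL/dh_t arriving from above. Only the last tbpttBackLength steps
     * are visited; earlier steps contribute nothing and keep epsilonOut = 0.
     */
    Result tbpttBackpropGradient(double[] x, double h0, double[] epsilon, int tbpttBackLength) {
        double[] h = forward(x, h0);
        double[] epsOut = new double[x.length];
        int stop = Math.max(0, x.length - tbpttBackLength);
        double gradW = 0.0;
        double gradU = 0.0;
        double carry = 0.0; // dL/dh_{t+1} accumulated within the window
        for (int t = x.length - 1; t >= stop; t--) {
            double g = epsilon[t] + u * carry; // dL/dh_t restricted to the window
            double hPrev = (t == 0) ? h0 : h[t - 1];
            gradW += g * x[t];
            gradU += g * hPrev;
            epsOut[t] = g * w;
            carry = g;
        }
        return new Result(gradW, gradU, epsOut);
    }
}
```

When tbpttBackLength covers the whole sequence, the result matches full BPTT (it agrees with a finite-difference check of L = sum of epsilon[t] * h_t); with tbpttBackLength = 1, gradU reduces to epsilon[T-1] * h_{T-2}, trading gradient accuracy for a bounded backward cost per segment.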