public interface RecurrentLayer extends Layer

Nested classes/interfaces inherited from interface Layer: Layer.TrainingMode, Layer.Type

| Modifier and Type | Method and Description |
|---|---|
| org.nd4j.linalg.api.ndarray.INDArray | rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) Similar to rnnTimeStep, this method computes activations using the state stored in the stateMap as the initialization. |
| void | rnnClearPreviousState() Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState(). |
| java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState() Returns a shallow copy of the RNN stateMap, which contains the stored history for use in methods such as rnnTimeStep. |
| java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetTBPTTState() Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
| void | rnnSetPreviousState(Map<String,INDArray> stateMap) Set the stateMap (stored history). |
| void | rnnSetTBPTTState(Map<String,INDArray> state) Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
| org.nd4j.linalg.api.ndarray.INDArray | rnnTimeStep(INDArray input) Do one or more time steps using the previous time step state stored in stateMap. Can be used to do the forward pass efficiently one (or n) steps at a time, instead of always starting from t=0. If stateMap is empty, default initialization (usually zeros) is used. Implementations also update stateMap at the end of this method. |
| Pair<Gradient,org.nd4j.linalg.api.ndarray.INDArray> | tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength) Truncated BPTT equivalent of Layer.backpropGradient(). |
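The stateful time-step contract described above (zeros initialization when stateMap is empty, stateMap updated at the end of each call) can be illustrated with a minimal plain-Java sketch. This toy class mirrors the interface's method names, but its "cell" math (a running sum over double[] arrays, standing in for INDArray) is purely illustrative and is not DL4J's actual implementation:

```java
import java.util.HashMap;
import java.util.Map;

// Toy stand-in for the RecurrentLayer state contract. Method names mirror
// the interface; the arithmetic is illustrative only.
class ToyRecurrent {
    private final Map<String, double[]> stateMap = new HashMap<>();

    // One time step: if stateMap is empty, default initialization (zeros)
    // is used; stateMap is updated at the end, as the contract requires.
    double[] rnnTimeStep(double[] input) {
        double[] prev = stateMap.getOrDefault("h", new double[input.length]);
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) out[i] = prev[i] + input[i];
        stateMap.put("h", out.clone());
        return out;
    }

    // Reset/clear the stored history.
    void rnnClearPreviousState() { stateMap.clear(); }

    // Shallow copy, matching the documented rnnGetPreviousState() behaviour.
    Map<String, double[]> rnnGetPreviousState() { return new HashMap<>(stateMap); }

    void rnnSetPreviousState(Map<String, double[]> s) {
        stateMap.clear();
        stateMap.putAll(s);
    }
}

public class Main {
    public static void main(String[] args) {
        ToyRecurrent layer = new ToyRecurrent();
        System.out.println(layer.rnnTimeStep(new double[]{1.0})[0]); // 1.0 (zero init)
        System.out.println(layer.rnnTimeStep(new double[]{1.0})[0]); // 2.0 (uses stored state)
        layer.rnnClearPreviousState();
        System.out.println(layer.rnnTimeStep(new double[]{1.0})[0]); // 1.0 again after clear
    }
}
```

The point of the pattern is that repeated rnnTimeStep calls carry state forward implicitly, so streaming inference over a long sequence never has to replay from t=0.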
Methods inherited from interface Layer:
activate, activationMean, backpropGradient, calcGradient, calcL1, calcL2, clone, derivativeActivation, error, feedForwardMaskArray, getIndex, getInputMiniBatchSize, getListeners, getMaskArray, isPretrainLayer, merge, preOutput, setIndex, setInput, setInputMiniBatchSize, setListeners, setMaskArray, transpose, type

Methods inherited from interface Model:
accumulateScore, applyLearningRateScoreDecay, batchSize, clear, computeGradientAndScore, conf, fit, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, numParams, params, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, validateInput

Method Detail

rnnTimeStep

org.nd4j.linalg.api.ndarray.INDArray rnnTimeStep(org.nd4j.linalg.api.ndarray.INDArray input)

Parameters:
input - Input to this layer

rnnGetPreviousState

java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()

rnnSetPreviousState

void rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)

rnnClearPreviousState

void rnnClearPreviousState()

rnnActivateUsingStoredState

org.nd4j.linalg.api.ndarray.INDArray rnnActivateUsingStoredState(org.nd4j.linalg.api.ndarray.INDArray input, boolean training, boolean storeLastForTBPTT)

Parameters:
input - Layer input
training - if true, training mode; otherwise, test mode
storeLastForTBPTT - if true, store the final state in tBpttStateMap for use in truncated BPTT training

rnnGetTBPTTState

java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()

rnnSetTBPTTState

void rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state)

Parameters:
state - TBPTT state to set

tbpttBackpropGradient

Pair<Gradient,org.nd4j.linalg.api.ndarray.INDArray> tbpttBackpropGradient(org.nd4j.linalg.api.ndarray.INDArray epsilon, int tbpttBackLength)
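Because rnnGetPreviousState() is documented to return a shallow copy, the returned map is a new object but the INDArrays inside it are shared with the layer's internal state. A small plain-Java sketch of that semantics (using double[] in place of INDArray):

```java
import java.util.HashMap;
import java.util.Map;

public class ShallowCopyDemo {
    public static void main(String[] args) {
        // Stand-in for the layer's internal stateMap (double[] in place of INDArray).
        Map<String, double[]> stateMap = new HashMap<>();
        stateMap.put("h", new double[]{1.0, 2.0});

        // A shallow copy, as rnnGetPreviousState() is documented to return:
        Map<String, double[]> snapshot = new HashMap<>(stateMap);

        // The two maps are distinct objects, but they share the same arrays,
        // so a mutation through one is visible through the other.
        snapshot.get("h")[0] = 99.0;
        System.out.println(stateMap.get("h")[0]); // 99.0
    }
}
```

In practice this means a state snapshot taken for later restoration via rnnSetPreviousState should be deep-copied (e.g. duplicating each INDArray) if the layer will keep stepping in the meantime.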