public abstract class BaseRecurrentLayer<LayerConfT extends Layer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Layer.TrainingMode, Layer.Type
Modifier and Type | Field and Description |
---|---|
protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> |
stateMap
stateMap stores the INDArrays needed for the rnnTimeStep() forward pass.
|
protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> |
tBpttStateMap
State map for use specifically in truncated BPTT training.
|
conf, dropoutApplied, dropoutMask, gradient, gradientsFlattened, gradientViews, index, input, iterationListeners, maskArray, maskState, optimizer, params, paramsFlattened, score, solver
Constructor and Description |
---|
BaseRecurrentLayer(NeuralNetConfiguration conf) |
BaseRecurrentLayer(NeuralNetConfiguration conf,
org.nd4j.linalg.api.ndarray.INDArray input) |
Modifier and Type | Method and Description |
---|---|
void |
rnnClearPreviousState()
Reset (clear) the stateMap used by rnnTimeStep() and the tBpttStateMap used by rnnActivateUsingStoredState().
|
java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> |
rnnGetPreviousState()
Returns a shallow copy of the stateMap
|
java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> |
rnnGetTBPTTState()
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
|
void |
rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
Set the stateMap used by rnnTimeStep().
|
void |
rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state)
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
|
accumulateScore, activate, activate, activate, activate, activate, activate, activationMean, applyDropOutIfNecessary, applyLearningRateScoreDecay, applyMask, backpropGradient, batchSize, calcGradient, calcL1, calcL2, clear, clone, computeGradientAndScore, conf, createGradient, derivativeActivation, error, feedForwardMaskArray, fit, fit, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, layerConf, layerNameAndIndex, merge, numParams, numParams, params, paramTable, paramTable, preOutput, preOutput, preOutput, preOutput, score, setBackpropGradientsViewArray, setConf, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, transpose, type, update, update, validateInput
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface RecurrentLayer: rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient
activate, activate, activate, activate, activate, activate, activationMean, backpropGradient, calcGradient, calcL1, calcL2, clone, derivativeActivation, error, feedForwardMaskArray, getIndex, getInputMiniBatchSize, getListeners, getMaskArray, isPretrainLayer, merge, preOutput, preOutput, preOutput, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, transpose, type
accumulateScore, applyLearningRateScoreDecay, batchSize, clear, computeGradientAndScore, conf, fit, fit, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update, validateInput
protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap
protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> tBpttStateMap
public BaseRecurrentLayer(NeuralNetConfiguration conf)
public BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input)
public java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()
Specified by: rnnGetPreviousState in interface RecurrentLayer
public void rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
Specified by: rnnSetPreviousState in interface RecurrentLayer
public void rnnClearPreviousState()
Specified by: rnnClearPreviousState in interface RecurrentLayer
public java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()
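Description copied from interface: RecurrentLayer
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnGetTBPTTState in interface RecurrentLayer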
public void rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state)
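Description copied from interface: RecurrentLayer
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnSetTBPTTState in interface RecurrentLayer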
Parameters:
state - TBPTT state to set