public abstract class BaseRecurrentLayer<LayerConfT extends Layer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Nested classes inherited from interface Layer: Layer.TrainingMode, Layer.Type

| Modifier and Type | Field and Description |
|---|---|
| protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | stateMap: stores the INDArrays needed for the rnnTimeStep() forward pass. |
| protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | tBpttStateMap: state map for use specifically in truncated BPTT training. |
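The split between the two maps can be illustrated with a minimal, self-contained sketch. It is not the DL4J implementation: double[] stands in for INDArray, the class ToyRecurrentLayer and the key "prevActivation" are hypothetical, and the recurrence h_t = tanh(w·x_t + u·h_{t-1}) is a simple illustrative cell rather than any particular DL4J layer. It shows that an rnnTimeStep()-style call reads and updates stateMap, so feeding a sequence one step at a time ends in the same activation as a single full-sequence pass, while tBpttStateMap is left untouched.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a recurrent layer: double[] replaces INDArray,
// and "prevActivation" is an illustrative key name, not a DL4J constant.
class ToyRecurrentLayer {
    static final String PREV_ACTIVATION = "prevActivation";

    // Plays the role of stateMap: state carried between rnnTimeStep() calls.
    protected Map<String, double[]> stateMap = new HashMap<>();
    // Plays the role of tBpttStateMap: kept separate so TBPTT training
    // does not disturb the step-by-step inference state (unused here).
    protected Map<String, double[]> tBpttStateMap = new HashMap<>();

    private final double w; // input weight
    private final double u; // recurrent weight

    ToyRecurrentLayer(double w, double u) {
        this.w = w;
        this.u = u;
    }

    // One time step: h_t = tanh(w * x_t + u * h_{t-1}), element-wise per unit.
    // Missing state is treated as a zero vector.
    double[] rnnTimeStep(double[] x) {
        double[] prev = stateMap.getOrDefault(PREV_ACTIVATION, new double[x.length]);
        double[] h = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            h[i] = Math.tanh(w * x[i] + u * prev[i]);
        }
        stateMap.put(PREV_ACTIVATION, h); // remember h_t for the next call
        return h;
    }

    // Full-sequence forward pass from a zero initial state; does not touch stateMap.
    double[] activateSequence(double[][] xs) {
        double[] h = new double[xs[0].length];
        for (double[] x : xs) {
            double[] next = new double[x.length];
            for (int i = 0; i < x.length; i++) {
                next[i] = Math.tanh(w * x[i] + u * h[i]);
            }
            h = next;
        }
        return h;
    }
}
```

Because the stored activation is exactly what the next step needs, step-by-step calls reproduce the full-sequence result; clearing stateMap would restart from the zero state.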
Fields inherited from class BaseLayer: conf, dropoutApplied, dropoutMask, gradient, gradientsFlattened, gradientViews, index, input, iterationListeners, maskArray, maskState, optimizer, params, paramsFlattened, score, solver

| Constructor and Description |
|---|
| BaseRecurrentLayer(NeuralNetConfiguration conf) |
| BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input) |
| Modifier and Type | Method and Description |
|---|---|
| void | rnnClearPreviousState(): Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState(). |
| java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState(): Returns a shallow copy of the stateMap. |
| java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetTBPTTState(): Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
| void | rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap): Set the state map. |
| void | rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state): Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
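The state-management contract in this table can be sketched in plain Java under stated assumptions: ToyStatefulLayer is a hypothetical class, double[] stands in for INDArray, and copying via new HashMap<>(...) is one way to obtain the shallow copy that rnnGetPreviousState() is documented to return. The sketch shows the practical consequence of "shallow": the returned Map is a separate object, so clearing the layer does not empty it, but the arrays inside are the same objects the layer holds.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the stored-state contract; not the DL4J source.
// double[] stands in for INDArray.
class ToyStatefulLayer {
    protected Map<String, double[]> stateMap = new HashMap<>();
    protected Map<String, double[]> tBpttStateMap = new HashMap<>();

    // Shallow copy: a new Map, but its values are the same array objects the layer holds.
    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    // Replace the stored state with the entries of the supplied map.
    void rnnSetPreviousState(Map<String, double[]> state) {
        stateMap.clear();
        stateMap.putAll(state);
    }

    // Clear both the rnnTimeStep() state and the TBPTT state.
    void rnnClearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }

    public static void main(String[] args) {
        ToyStatefulLayer layer = new ToyStatefulLayer();
        layer.stateMap.put("prevActivation", new double[] {0.1, 0.2}); // illustrative key

        Map<String, double[]> saved = layer.rnnGetPreviousState(); // snapshot the map
        layer.rnnClearPreviousState();                             // start a fresh sequence
        System.out.println(layer.stateMap.isEmpty());              // true
        layer.rnnSetPreviousState(saved);                          // resume where we left off
        System.out.println(layer.stateMap.get("prevActivation")[1]); // 0.2
    }
}
```

Because the copy is shallow, a caller that wants a snapshot immune to later in-place updates of the arrays would need to duplicate each array as well; whether a given layer updates its arrays in place is not stated here.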
Methods inherited from class BaseLayer: accumulateScore, activate, activate, activate, activate, activate, activate, activationMean, applyDropOutIfNecessary, applyLearningRateScoreDecay, applyMask, backpropGradient, batchSize, calcGradient, calcL1, calcL2, clear, clone, computeGradientAndScore, conf, createGradient, derivativeActivation, error, feedForwardMaskArray, fit, fit, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, layerConf, layerNameAndIndex, merge, numParams, numParams, params, paramTable, paramTable, preOutput, preOutput, preOutput, preOutput, score, setBackpropGradientsViewArray, setConf, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, transpose, type, update, update, validateInput

Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface RecurrentLayer: rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient

Methods inherited from interface Layer: activate, activate, activate, activate, activate, activate, activationMean, backpropGradient, calcGradient, calcL1, calcL2, clone, derivativeActivation, error, feedForwardMaskArray, getIndex, getInputMiniBatchSize, getListeners, getMaskArray, isPretrainLayer, merge, preOutput, preOutput, preOutput, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, transpose, type

Methods inherited from interface Model: accumulateScore, applyLearningRateScoreDecay, batchSize, clear, computeGradientAndScore, conf, fit, fit, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update, validateInput

protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap
stateMap stores the INDArrays needed for the rnnTimeStep() forward pass.
protected java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> tBpttStateMap
State map for use specifically in truncated BPTT training.
public BaseRecurrentLayer(NeuralNetConfiguration conf)
public BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input)
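public java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()
Returns a shallow copy of the stateMap.
Specified by: rnnGetPreviousState in interface RecurrentLayer

public void rnnSetPreviousState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
Set the state map.
Specified by: rnnSetPreviousState in interface RecurrentLayer

public void rnnClearPreviousState()
Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
Specified by: rnnClearPreviousState in interface RecurrentLayer

public java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnGetTBPTTState in interface RecurrentLayer

public void rnnSetTBPTTState(java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> state)
Description copied from interface: RecurrentLayer
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnSetTBPTTState in interface RecurrentLayer
Parameters: state - TBPTT state to set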
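The role of the TBPTT getter and setter can be sketched as follows, assuming a hypothetical ToyTbpttLayer, double[] in place of INDArray, and the illustrative key "prevActivation". Only the forward-pass state hand-off of truncated BPTT is shown: a long sequence is split into segments of length k, each segment starts from the state saved in tBpttStateMap by the previous one, and no gradients are computed. The gradient truncation itself (tbpttBackpropGradient in DL4J) is deliberately omitted.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: forward-state carry-over between TBPTT segments only.
// double[] stands in for INDArray; "prevActivation" is an illustrative key.
class ToyTbpttLayer {
    static final String KEY = "prevActivation";
    protected Map<String, double[]> stateMap = new HashMap<>();     // inference state (unused here)
    protected Map<String, double[]> tBpttStateMap = new HashMap<>(); // state between segments
    private final double w, u;

    ToyTbpttLayer(double w, double u) {
        this.w = w;
        this.u = u;
    }

    Map<String, double[]> rnnGetTBPTTState() {
        return new HashMap<>(tBpttStateMap);
    }

    void rnnSetTBPTTState(Map<String, double[]> state) {
        tBpttStateMap.clear();
        tBpttStateMap.putAll(state);
    }

    // Forward over one segment, starting from the stored TBPTT state (zeros if none),
    // then save the segment's final activation for the next segment.
    double[] forwardSegment(double[][] segment) {
        double[] h = tBpttStateMap.getOrDefault(KEY, new double[segment[0].length]);
        for (double[] x : segment) {
            double[] next = new double[x.length];
            for (int i = 0; i < x.length; i++) {
                next[i] = Math.tanh(w * x[i] + u * h[i]);
            }
            h = next;
        }
        tBpttStateMap.put(KEY, h);
        return h;
    }

    // Split a sequence into segments of length k and run them in order.
    double[] forwardInSegments(double[][] xs, int k) {
        rnnSetTBPTTState(new HashMap<>()); // fresh TBPTT state for each new sequence
        double[] h = null;
        for (int start = 0; start < xs.length; start += k) {
            int end = Math.min(start + k, xs.length);
            h = forwardSegment(Arrays.copyOfRange(xs, start, end));
            // In real TBPTT, the backward pass for this segment would run here,
            // with gradients stopped at the segment boundary.
        }
        return h;
    }
}
```

Since the state is handed over exactly at each boundary, the segmented forward pass ends in the same activation as an unsegmented one; what TBPTT changes is how far gradients flow back, not the forward values.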