public class DefaultGradient extends java.lang.Object implements Gradient

| Modifier and Type | Field and Description |
|---|---|
| static char | DEFAULT_FLATTENING_ORDER |


| Constructor and Description |
|---|
| DefaultGradient() |
| DefaultGradient(org.nd4j.linalg.api.ndarray.INDArray flattenedGradient) |


| Modifier and Type | Method | Description |
|---|---|---|
| void | clear() | Clear residual parameters (useful for returning a gradient and then clearing old objects). |
| java.lang.Character | flatteningOrderForVariable(java.lang.String variable) | Return the gradient flattening order for the specified variable, or null if it is not explicitly set. |
| org.nd4j.linalg.api.ndarray.INDArray | getGradientFor(java.lang.String variable) | The gradient for the given variable. |
| org.nd4j.linalg.api.ndarray.INDArray | gradient() | The full gradient as one flat vector. |
| org.nd4j.linalg.api.ndarray.INDArray | gradient(java.util.List<java.lang.String> order) | The full gradient as one flat vector, with the per-variable gradients taken in the given order. |
| java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> | gradientForVariable() | Gradient lookup table (variable name to gradient array). |
| org.nd4j.linalg.api.ndarray.INDArray | setGradientFor(java.lang.String variable, org.nd4j.linalg.api.ndarray.INDArray newGradient) | Update the gradient for the given variable. |
| org.nd4j.linalg.api.ndarray.INDArray | setGradientFor(java.lang.String variable, org.nd4j.linalg.api.ndarray.INDArray gradient, java.lang.Character flatteningOrder) | Update the gradient for the given variable; optionally also specify the order in which the array should be flattened to a row vector. |
| java.lang.String | toString() | |

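
The lookup-table behavior summarized above can be modeled without ND4J. The sketch below is hypothetical and is not the DL4J implementation: the class name `GradientTableSketch` is invented, `double[]` stands in for `INDArray`, and two behaviors are assumptions rather than documented facts: `setGradientFor` returns the previously stored array (as `java.util.Map.put` does), and `clear()` removes only the stored gradients, keeping any recorded flattening orders.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical, dependency-free model of the per-variable gradient table.
 * double[] stands in for INDArray; this is not the DL4J class.
 */
class GradientTableSketch {
    // Variable name -> gradient values; LinkedHashMap keeps insertion order (an assumption).
    private final Map<String, double[]> gradients = new LinkedHashMap<>();
    // Only variables whose flattening order was explicitly set appear here.
    private final Map<String, Character> flatteningOrders = new LinkedHashMap<>();

    /** Stores the gradient; returns the previously stored array or null (assumed, like Map.put). */
    double[] setGradientFor(String variable, double[] newGradient) {
        return gradients.put(variable, newGradient);
    }

    /** As above, additionally recording a flattening order unless it is null. */
    double[] setGradientFor(String variable, double[] gradient, Character flatteningOrder) {
        double[] previous = setGradientFor(variable, gradient);
        if (flatteningOrder != null) {
            flatteningOrders.put(variable, flatteningOrder);
        }
        return previous;
    }

    /** Returns the stored gradient, or null if none was set. */
    double[] getGradientFor(String variable) {
        return gradients.get(variable);
    }

    /** Returns the explicitly set order, or null if none was set. */
    Character flatteningOrderForVariable(String variable) {
        return flatteningOrders.get(variable);
    }

    /** In this sketch, exposes the live internal lookup table. */
    Map<String, double[]> gradientForVariable() {
        return gradients;
    }

    /** Removes all stored gradients; recorded flattening orders are kept (an assumption). */
    void clear() {
        gradients.clear();
    }

    public static void main(String[] args) {
        GradientTableSketch g = new GradientTableSketch();
        g.setGradientFor("W", new double[] {0.1, -0.2}, 'c');
        g.setGradientFor("b", new double[] {0.05}, null);
        System.out.println(g.flatteningOrderForVariable("W")); // c
        System.out.println(g.flatteningOrderForVariable("b")); // null
        g.clear();
        System.out.println(g.gradientForVariable().isEmpty()); // true
    }
}
```

Passing `null` as the flattening order leaves the variable without an explicit order, which is why `flatteningOrderForVariable("b")` reports `null` here.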
public static final char DEFAULT_FLATTENING_ORDER
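
A flattening order determines how a multi-dimensional gradient array is laid out when it is flattened to a row vector. In ND4J, `'c'` denotes row-major (C-style) order and `'f'` column-major (Fortran-style) order. The hypothetical helper below, `FlatteningOrderDemo.flatten`, illustrates the difference on a plain `double[][]`; it does not use ND4J and is not how `DefaultGradient` is implemented.

```java
/**
 * Hypothetical helper (not part of ND4J or DL4J) showing what a flattening order
 * means for a rectangular matrix: 'c' = row-major, 'f' = column-major.
 */
class FlatteningOrderDemo {

    static double[] flatten(double[][] matrix, char order) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        double[] out = new double[rows * cols];
        int k = 0;
        if (order == 'c') {
            // Row-major: finish each row before moving to the next.
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    out[k++] = matrix[r][c];
                }
            }
        } else if (order == 'f') {
            // Column-major: finish each column before moving to the next.
            for (int c = 0; c < cols; c++) {
                for (int r = 0; r < rows; r++) {
                    out[k++] = matrix[r][c];
                }
            }
        } else {
            throw new IllegalArgumentException("Unsupported flattening order: " + order);
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = {{1, 2, 3}, {4, 5, 6}};
        System.out.println(java.util.Arrays.toString(flatten(m, 'c'))); // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        System.out.println(java.util.Arrays.toString(flatten(m, 'f'))); // [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
    }
}
```

The same 2 x 3 matrix yields two different row vectors, which is why a gradient's flattening order has to match the order used for the corresponding parameters.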

public DefaultGradient()

public DefaultGradient(org.nd4j.linalg.api.ndarray.INDArray flattenedGradient)
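
public java.util.Map<java.lang.String,org.nd4j.linalg.api.ndarray.INDArray> gradientForVariable()

Description copied from interface: Gradient

Gradient lookup table, mapping each variable name to its gradient array.

Specified by: gradientForVariable in interface Gradient

public org.nd4j.linalg.api.ndarray.INDArray gradient(java.util.List<java.lang.String> order)

Description copied from interface: Gradient

The full gradient as one flat vector, with the per-variable gradients taken in the given order.

Specified by: gradient in interface Gradient

public org.nd4j.linalg.api.ndarray.INDArray gradient()

Description copied from interface: Gradient

The full gradient as one flat vector.

Specified by: gradient in interface Gradient

public void clear()

Description copied from interface: Gradient

Clear residual parameters (useful for returning a gradient and then clearing old objects).

Specified by: clear in interface Gradient

public org.nd4j.linalg.api.ndarray.INDArray getGradientFor(java.lang.String variable)

Description copied from interface: Gradient

The gradient for the given variable.

Specified by: getGradientFor in interface Gradient

Parameters: variable - the variable to get the gradient for

public org.nd4j.linalg.api.ndarray.INDArray setGradientFor(java.lang.String variable, org.nd4j.linalg.api.ndarray.INDArray newGradient)

Description copied from interface: Gradient

Update the gradient for the given variable.

Specified by: setGradientFor in interface Gradient

Parameters:

variable - the variable to set the gradient for

newGradient - the gradient values

public org.nd4j.linalg.api.ndarray.INDArray setGradientFor(java.lang.String variable, org.nd4j.linalg.api.ndarray.INDArray gradient, java.lang.Character flatteningOrder)

Description copied from interface: Gradient

Update the gradient for the given variable; optionally also specify the order in which the array should be flattened to a row vector.

Specified by: setGradientFor in interface Gradient

Parameters:

variable - the variable to set the gradient for

gradient - the gradient values

flatteningOrder - the order in which the gradient should be flattened; may be null, in which case the default order is used

public java.lang.Character flatteningOrderForVariable(java.lang.String variable)

Description copied from interface: Gradient

Return the gradient flattening order for the specified variable, or null if it is not explicitly set.

Specified by: flatteningOrderForVariable in interface Gradient

Parameters: variable - the variable to return the gradient flattening order for

public java.lang.String toString()

Overrides: toString in class java.lang.Object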
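
The relationship between the per-variable gradients, their optional flattening orders, and the single flat vector returned by `gradient(java.util.List<java.lang.String> order)` can be sketched as follows, under stated assumptions. `FlatGradientSketch` and its method signature are hypothetical; `double[][]` stands in for `INDArray`; a variable's explicit flattening order wins, otherwise a caller-supplied fallback (standing in for `DEFAULT_FLATTENING_ORDER`) is used; and names listed in `order` that have no stored gradient are skipped, which is an assumption rather than documented behavior.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical model (not the DL4J implementation) of building one flat gradient
 * vector from per-variable gradients, visiting variables in a caller-supplied order.
 */
class FlatGradientSketch {

    /** Flattens a rectangular matrix in 'c' (row-major) or 'f' (column-major) order. */
    static double[] flatten(double[][] m, char order) {
        int rows = m.length;
        int cols = rows == 0 ? 0 : m[0].length;
        double[] out = new double[rows * cols];
        int k = 0;
        if (order == 'c') {
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    out[k++] = m[r][c];
                }
            }
        } else if (order == 'f') {
            for (int c = 0; c < cols; c++) {
                for (int r = 0; r < rows; r++) {
                    out[k++] = m[r][c];
                }
            }
        } else {
            throw new IllegalArgumentException("Unsupported flattening order: " + order);
        }
        return out;
    }

    /** Concatenates the flattened gradients of the variables named in order. */
    static double[] gradient(Map<String, double[][]> gradients,
                             Map<String, Character> flatteningOrders,
                             char defaultOrder,
                             List<String> order) {
        List<double[]> parts = new ArrayList<>();
        int total = 0;
        for (String name : order) {
            double[][] g = gradients.get(name);
            if (g == null) {
                continue; // assumption: missing variables are skipped, not reported
            }
            Character explicit = flatteningOrders.get(name);
            double[] flat = flatten(g, explicit != null ? explicit : defaultOrder);
            parts.add(flat);
            total += flat.length;
        }
        // Copy each flattened segment into one contiguous result vector.
        double[] result = new double[total];
        int offset = 0;
        for (double[] part : parts) {
            System.arraycopy(part, 0, result, offset, part.length);
            offset += part.length;
        }
        return result;
    }
}
```

In this sketch, changing the `order` list only moves each variable's segment within the result; the contents of each segment depend solely on that variable's own flattening order.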