public class GaussianReconstructionDistribution extends java.lang.Object implements ReconstructionDistribution
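The GaussianReconstructionDistribution models each output value with a Gaussian whose mean and log variance, log(stdev^2), are given by the distribution parameters. Using the log variance means that a parameter value of 0 corresponds to unit variance (log(1) = 0), and any value in (-infinity, infinity) maps to a valid, strictly positive variance. Other parameterizations of the variance are possible, but they may interact poorly with typical pre-activation values (which are centred around zero) and with the output ranges of activation functions.
For the activation function, identity and possibly tanh are typical. Unlike identity, tanh bounds both the mean and the log variance to a minimum and maximum value. Asymmetric activation functions such as sigmoid or relu should be avoided: relu, for example, outputs only non-negative values, so the log variance could never fall below 0 and the variance could never fall below 1.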
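The parameterization above can be illustrated with a minimal plain-Java sketch that uses `double[]` instead of ND4J arrays. It assumes, as a hypothetical layout not confirmed by this page, that the parameters for one example hold 2 * nIn values: the first nIn are means and the last nIn are log variances. The class name `GaussianParamSketch` and its members are illustrative only.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

/**
 * Sketch of the mean / log-variance parameterization.
 * Hypothetical layout (an assumption): [mean_0..mean_{n-1}, logVar_0..logVar_{n-1}].
 */
final class GaussianParamSketch {

    /** Activation functions applied elementwise to the pre-activation values. */
    static final DoubleUnaryOperator IDENTITY = v -> v;
    static final DoubleUnaryOperator TANH = Math::tanh;

    /** Returns {means, stdevs} for one example. */
    static double[][] meanAndStdev(double[] preOut, DoubleUnaryOperator activation) {
        if (preOut.length % 2 != 0) {
            throw new IllegalArgumentException("expected 2 * nIn values, got " + preOut.length);
        }
        int nIn = preOut.length / 2;
        double[] mean = new double[nIn];
        double[] stdev = new double[nIn];
        for (int i = 0; i < nIn; i++) {
            mean[i] = activation.applyAsDouble(preOut[i]);
            double logVar = activation.applyAsDouble(preOut[nIn + i]);
            // logVar = log(stdev^2), so stdev = exp(logVar / 2): positive for any finite logVar
            stdev[i] = Math.exp(0.5 * logVar);
        }
        return new double[][] {mean, stdev};
    }

    public static void main(String[] args) {
        double[] preOut = {0.5, -2.0, 0.0, 10.0}; // nIn = 2
        double[][] id = meanAndStdev(preOut, IDENTITY);
        double[][] th = meanAndStdev(preOut, TANH);
        System.out.println("identity: mean=" + Arrays.toString(id[0]) + " stdev=" + Arrays.toString(id[1]));
        System.out.println("tanh:     mean=" + Arrays.toString(th[0]) + " stdev=" + Arrays.toString(th[1]));
    }
}
```

With identity, a log variance of 0 gives a standard deviation of exactly 1 and a log variance of 10 gives exp(5). With tanh, the log variance is squashed into (-1, 1), so the standard deviation can never leave (exp(-0.5), exp(0.5)); this is the minimum/maximum effect described above.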
Constructor | Description |
---|---|
GaussianReconstructionDistribution() | Create a GaussianReconstructionDistribution with the default identity activation function. |
GaussianReconstructionDistribution(org.nd4j.linalg.activations.Activation activationFn) | |
GaussianReconstructionDistribution(org.nd4j.linalg.activations.IActivation activationFn) | |
GaussianReconstructionDistribution(java.lang.String activationFn) | Deprecated. |
Modifier and Type | Method | Description |
---|---|---|
int | distributionInputSize(int dataSize) | Get the number of distribution parameters for the given input data size. |
org.nd4j.linalg.api.ndarray.INDArray | exampleNegLogProbability(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams) | Calculate the negative log probability for each example individually. |
org.nd4j.linalg.api.ndarray.INDArray | generateAtMean(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams) | Generate a sample from P(x\|z), where x = E[P(x\|z)], i.e., return the mean value of the distribution. |
org.nd4j.linalg.api.ndarray.INDArray | generateRandom(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams) | Randomly sample from P(x\|z) using the specified distribution parameters. |
org.nd4j.linalg.api.ndarray.INDArray | gradient(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams) | Calculate the gradient of the negative log probability with respect to preOutDistributionParams. |
boolean | hasLossFunction() | Does this reconstruction distribution have a standard neural network loss function (such as mean squared error, which is deterministic), or is it a standard VAE with a probabilistic reconstruction distribution? |
double | negLogProbability(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams, boolean average) | Calculate the negative log probability, summed or averaged over the examples in the minibatch. |
java.lang.String | toString() | |
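The methods in the summary can be sketched in plain Java to show what each one computes. This is a minimal re-derivation, not the library source: it assumes identity activation, plain `double[][]` arrays (one row per example) in place of INDArray, and the same hypothetical row layout of all means followed by all log variances. The per-value negative log probability used here is the full Gaussian density, 0.5 * log(2 * pi) + 0.5 * logVar + (x - mean)^2 / (2 * exp(logVar)); whether the library includes the constant term is an assumption. The class `GaussianReconSketch` is hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Plain-Java sketch of the summary methods (identity activation assumed).
 * Hypothetical row layout: [mean_0..mean_{n-1}, logVar_0..logVar_{n-1}].
 */
final class GaussianReconSketch {
    private static final double HALF_LOG_2PI = 0.5 * Math.log(2 * Math.PI);

    /** Two parameters (mean and log variance) per data value. */
    static int distributionInputSize(int dataSize) {
        return 2 * dataSize;
    }

    /** Negative log probability of each example (one value per row). */
    static double[] exampleNegLogProbability(double[][] x, double[][] preOut) {
        double[] out = new double[x.length];
        for (int r = 0; r < x.length; r++) {
            int n = x[r].length;
            double nll = 0.0;
            for (int i = 0; i < n; i++) {
                double mean = preOut[r][i];
                double logVar = preOut[r][n + i];
                double diff = x[r][i] - mean;
                // -log N(x; mean, exp(logVar)), summed over the values in this example
                nll += HALF_LOG_2PI + 0.5 * logVar + diff * diff / (2.0 * Math.exp(logVar));
            }
            out[r] = nll;
        }
        return out;
    }

    /** Summed over the minibatch, or averaged when average == true. */
    static double negLogProbability(double[][] x, double[][] preOut, boolean average) {
        double total = 0.0;
        for (double v : exampleNegLogProbability(x, preOut)) {
            total += v;
        }
        return average ? total / x.length : total;
    }

    /** x = E[P(x|z)]: the mean half of each row. */
    static double[][] generateAtMean(double[][] preOut) {
        double[][] out = new double[preOut.length][];
        for (int r = 0; r < preOut.length; r++) {
            out[r] = Arrays.copyOf(preOut[r], preOut[r].length / 2);
        }
        return out;
    }

    /** Random sample: mean + stdev * N(0, 1) for each value. */
    static double[][] generateRandom(double[][] preOut, Random rng) {
        double[][] out = new double[preOut.length][];
        for (int r = 0; r < preOut.length; r++) {
            int n = preOut[r].length / 2;
            out[r] = new double[n];
            for (int i = 0; i < n; i++) {
                double stdev = Math.exp(0.5 * preOut[r][n + i]);
                out[r][i] = preOut[r][i] + stdev * rng.nextGaussian();
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{0.0}, {1.0}};
        double[][] preOut = {{0.0, 0.0}, {0.0, 0.0}}; // mean 0, log variance 0 (unit variance)
        System.out.println(distributionInputSize(1));
        System.out.println(negLogProbability(x, preOut, false));
        System.out.println(negLogProbability(x, preOut, true));
        System.out.println(Arrays.toString(generateRandom(preOut, new Random(42))[0]));
    }
}
```

In this sketch, distributionInputSize doubles the data size because every data value needs both a mean and a log variance, and the only difference between the summed and averaged negLogProbability is the final division by the minibatch size.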
public GaussianReconstructionDistribution()

@Deprecated public GaussianReconstructionDistribution(java.lang.String activationFn)
Deprecated. Use GaussianReconstructionDistribution(Activation) instead.

public GaussianReconstructionDistribution(org.nd4j.linalg.activations.Activation activationFn)
Parameters:
activationFn - Activation function for the reconstruction distribution. Typically identity or tanh.

public GaussianReconstructionDistribution(org.nd4j.linalg.activations.IActivation activationFn)
Parameters:
activationFn - Activation function for the reconstruction distribution. Typically identity or tanh.

public boolean hasLossFunction()
Specified by: hasLossFunction in interface ReconstructionDistribution

public int distributionInputSize(int dataSize)
Specified by: distributionInputSize in interface ReconstructionDistribution
Parameters:
dataSize - Size of the data, i.e., the nIn value.

public double negLogProbability(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams, boolean average)
Specified by: negLogProbability in interface ReconstructionDistribution
Parameters:
x - Data to be modelled (reconstructions).
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian).
average - Whether the log probability should be averaged over the minibatch or simply summed.

public org.nd4j.linalg.api.ndarray.INDArray exampleNegLogProbability(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: exampleNegLogProbability in interface ReconstructionDistribution
Parameters:
x - Data to be modelled (reconstructions).
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function.

public org.nd4j.linalg.api.ndarray.INDArray gradient(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: gradient in interface ReconstructionDistribution
Parameters:
x - Data.
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function.

public org.nd4j.linalg.api.ndarray.INDArray generateRandom(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: generateRandom in interface ReconstructionDistribution
Parameters:
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function.

public org.nd4j.linalg.api.ndarray.INDArray generateAtMean(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: generateAtMean in interface ReconstructionDistribution
Parameters:
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function.

public java.lang.String toString()
Overrides: toString in class java.lang.Object
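Because gradient(x, preOutDistributionParams) works on the parameters before the activation function, its result is the derivative of the Gaussian negative log probability multiplied by the activation derivative (chain rule). Writing var = exp(logVar), the per-value derivatives are d/d(mean) = (mean - x) / var and d/d(logVar) = 0.5 * (1 - (x - mean)^2 / var); for tanh each is further multiplied by 1 - tanh(pre)^2, while for identity the factor is 1. The plain-Java sketch below is a re-derivation under the same assumptions as before (hypothetical class name `GaussianGradientSketch`, `double[]` instead of INDArray, means-then-log-variances layout) and checks the analytic gradient against a central finite difference.

```java
import java.util.Arrays;

/**
 * Sketch of the gradient of the Gaussian negative log probability
 * with respect to the pre-activation parameters of one example.
 * Hypothetical layout: [mean_0..mean_{n-1}, logVar_0..logVar_{n-1}].
 */
final class GaussianGradientSketch {
    private static final double HALF_LOG_2PI = 0.5 * Math.log(2 * Math.PI);

    /** Identity or tanh activation. */
    static double act(double v, boolean tanh) {
        return tanh ? Math.tanh(v) : v;
    }

    /** Derivative of the activation at pre-activation value v. */
    static double actDeriv(double v, boolean tanh) {
        if (!tanh) {
            return 1.0;
        }
        double t = Math.tanh(v);
        return 1.0 - t * t;
    }

    static double negLogProb(double[] x, double[] preOut, boolean tanh) {
        int n = x.length;
        double nll = 0.0;
        for (int i = 0; i < n; i++) {
            double mean = act(preOut[i], tanh);
            double logVar = act(preOut[n + i], tanh);
            double diff = x[i] - mean;
            nll += HALF_LOG_2PI + 0.5 * logVar + diff * diff / (2.0 * Math.exp(logVar));
        }
        return nll;
    }

    /** Analytic gradient: same shape as preOut. */
    static double[] gradient(double[] x, double[] preOut, boolean tanh) {
        int n = x.length;
        double[] grad = new double[2 * n];
        for (int i = 0; i < n; i++) {
            double mean = act(preOut[i], tanh);
            double logVar = act(preOut[n + i], tanh);
            double diff = x[i] - mean;
            double invVar = Math.exp(-logVar);
            // dNLL/dmean = (mean - x) / var, then chain rule through the activation
            grad[i] = (mean - x[i]) * invVar * actDeriv(preOut[i], tanh);
            // dNLL/dlogVar = 0.5 * (1 - (x - mean)^2 / var), then chain rule
            grad[n + i] = 0.5 * (1.0 - diff * diff * invVar) * actDeriv(preOut[n + i], tanh);
        }
        return grad;
    }

    /** Central finite-difference estimate, used to check gradient(). */
    static double[] numericGradient(double[] x, double[] preOut, boolean tanh, double h) {
        double[] g = new double[preOut.length];
        for (int j = 0; j < preOut.length; j++) {
            double[] plus = preOut.clone();
            double[] minus = preOut.clone();
            plus[j] += h;
            minus[j] -= h;
            g[j] = (negLogProb(x, plus, tanh) - negLogProb(x, minus, tanh)) / (2 * h);
        }
        return g;
    }

    public static void main(String[] args) {
        double[] x = {0.3, -1.2};
        double[] preOut = {0.1, -0.4, 0.2, -0.7};
        System.out.println(Arrays.toString(gradient(x, preOut, true)));
        System.out.println(Arrays.toString(numericGradient(x, preOut, true, 1e-6)));
    }
}
```

When the data equals the mean, the mean gradient is zero and the log-variance gradient is 0.5 (under identity activation), which pushes the log variance down; this is the expected behaviour when the model already predicts the data exactly.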