public class CompositeReconstructionDistribution extends java.lang.Object implements ReconstructionDistribution
… GaussianReconstructionDistribution, the next 10 values as binary/Bernoulli (with a BernoulliReconstructionDistribution).

| Modifier and Type | Class and Description |
|---|---|
| static class | CompositeReconstructionDistribution.Builder |
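The class description splits the data columns into consecutive blocks, each modelled by its own sub-distribution. The plain-Java sketch below illustrates that bookkeeping, assuming sub-distributions claim columns in the order they are added. It is not the DL4J Builder: `SubDistribution`, `CompositeSpecBuilder` and `add` are hypothetical names. The three accessors simply mirror the three constructor arguments.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical placeholder for a sub-distribution; only its name matters here.
final class SubDistribution {
    final String name;

    SubDistribution(String name) {
        this.name = name;
    }
}

// Hypothetical builder (not the DL4J API): each add(...) call assigns the next
// `size` data columns to one sub-distribution, in call order.
final class CompositeSpecBuilder {
    private final List<Integer> sizes = new ArrayList<>();
    private final List<SubDistribution> distributions = new ArrayList<>();

    CompositeSpecBuilder add(int size, SubDistribution distribution) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive, got " + size);
        }
        sizes.add(size);
        distributions.add(distribution);
        return this;
    }

    // Mirrors the constructor's distributionSizes argument.
    int[] distributionSizes() {
        return sizes.stream().mapToInt(Integer::intValue).toArray();
    }

    // Mirrors the constructor's reconstructionDistributions argument.
    SubDistribution[] distributions() {
        return distributions.toArray(new SubDistribution[0]);
    }

    // Mirrors the constructor's totalSize argument: the number of columns covered.
    int totalSize() {
        return sizes.stream().mapToInt(Integer::intValue).sum();
    }
}

final class CompositeSpecBuilderDemo {
    public static void main(String[] args) {
        // 10 continuous columns followed by 10 binary columns.
        CompositeSpecBuilder b = new CompositeSpecBuilder()
                .add(10, new SubDistribution("gaussian"))
                .add(10, new SubDistribution("bernoulli"));
        System.out.println(Arrays.toString(b.distributionSizes()) + " total=" + b.totalSize());
    }
}
```

For the Gaussian/Bernoulli split in the class description, this prints `[10, 10] total=20`.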

| Constructor and Description |
|---|
| CompositeReconstructionDistribution(int[] distributionSizes, ReconstructionDistribution[] reconstructionDistributions, int totalSize) |
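The constructor takes both the size of each sub-distribution and an overall `totalSize`. Assuming each sub-distribution owns a contiguous block of columns in the order given, the column ranges are prefix sums of `distributionSizes`, and those sizes should add up to `totalSize`. The plain-Java sketch below computes the ranges and rejects an inconsistent `totalSize`. `ColumnLayout` is a hypothetical helper that uses `double[]` in place of INDArray. Whether the real constructor performs this check is an assumption, not something stated on this page.

```java
import java.util.Arrays;

final class ColumnLayout {
    // Prefix sums of distributionSizes: distribution i owns data columns
    // [offsets[i], offsets[i + 1]).
    static int[] offsets(int[] distributionSizes, int totalSize) {
        int[] offsets = new int[distributionSizes.length + 1];
        for (int i = 0; i < distributionSizes.length; i++) {
            offsets[i + 1] = offsets[i] + distributionSizes[i];
        }
        int covered = offsets[distributionSizes.length];
        if (covered != totalSize) {
            throw new IllegalArgumentException(
                    "distributionSizes sum to " + covered + " but totalSize is " + totalSize);
        }
        return offsets;
    }

    // Extracts the columns belonging to distribution i from one data row.
    static double[] columnsFor(double[] row, int[] offsets, int i) {
        return Arrays.copyOfRange(row, offsets[i], offsets[i + 1]);
    }

    public static void main(String[] args) {
        int[] offsets = offsets(new int[] {3, 2}, 5);
        double[] row = {0.1, 0.2, 0.3, 1.0, 0.0};
        System.out.println(Arrays.toString(offsets));                     // [0, 3, 5]
        System.out.println(Arrays.toString(columnsFor(row, offsets, 1))); // [1.0, 0.0]
    }
}
```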

| Modifier and Type | Method and Description |
|---|---|
| org.nd4j.linalg.api.ndarray.INDArray | computeLossFunctionScoreArray(org.nd4j.linalg.api.ndarray.INDArray data, org.nd4j.linalg.api.ndarray.INDArray reconstruction) |
| int | distributionInputSize(int dataSize): Get the number of distribution parameters for the given input data size. |
| org.nd4j.linalg.api.ndarray.INDArray | exampleNegLogProbability(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams): Calculate the negative log probability for each example individually. |
| org.nd4j.linalg.api.ndarray.INDArray | generateAtMean(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams): Generate a sample from P(x\|z), where x = E[P(x\|z)], i.e., return the mean value of the distribution. |
| org.nd4j.linalg.api.ndarray.INDArray | generateRandom(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams): Randomly sample from P(x\|z) using the specified distribution parameters. |
| org.nd4j.linalg.api.ndarray.INDArray | gradient(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams): Calculate the gradient of the negative log probability with respect to preOutDistributionParams. |
| boolean | hasLossFunction(): Does this reconstruction distribution have a standard neural network loss function (such as mean squared error, which is deterministic), or is it a standard VAE with a probabilistic reconstruction distribution? |
| double | negLogProbability(org.nd4j.linalg.api.ndarray.INDArray x, org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams, boolean average): Calculate the negative log probability (summed or averaged over each example in the minibatch). |
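`distributionInputSize(int dataSize)` returns the number of distribution parameters needed for a given data size. For a composite, one natural reading is that each sub-distribution contributes the parameter count it needs for its own slice of the data. The sketch below assumes this. It also assumes a fixed number of parameters per data value for each sub-distribution: 2 for a Gaussian (mean and log variance, matching the parameter descriptions on this page) and 1 for a Bernoulli. `CompositeParamCount` and `paramsPerValue` are hypothetical, and the `dataSize == totalSize` check is an assumption.

```java
final class CompositeParamCount {
    // paramsPerValue[i] is an assumption about sub-distribution i: how many
    // distribution parameters it needs per data value (2 for a Gaussian with
    // mean and log variance, 1 for a Bernoulli).
    static int distributionInputSize(int dataSize, int[] distributionSizes, int[] paramsPerValue) {
        if (distributionSizes.length != paramsPerValue.length) {
            throw new IllegalArgumentException("one paramsPerValue entry is needed per distribution");
        }
        int totalSize = 0;
        int paramCount = 0;
        for (int i = 0; i < distributionSizes.length; i++) {
            totalSize += distributionSizes[i];
            // Each sub-distribution contributes the parameters for its own slice of the data.
            paramCount += distributionSizes[i] * paramsPerValue[i];
        }
        if (dataSize != totalSize) {
            throw new IllegalArgumentException("dataSize " + dataSize + " != totalSize " + totalSize);
        }
        return paramCount;
    }

    public static void main(String[] args) {
        // 10 Gaussian columns (2 params each) + 10 Bernoulli columns (1 param each)
        System.out.println(distributionInputSize(20, new int[] {10, 10}, new int[] {2, 1})); // 30
    }
}
```

Under these assumptions, the parameter vector for a composite is wider than the data whenever any component needs more than one parameter per value. In this example, 20 data columns map to 30 parameter columns.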
public CompositeReconstructionDistribution(int[] distributionSizes,
ReconstructionDistribution[] reconstructionDistributions,
int totalSize)
public org.nd4j.linalg.api.ndarray.INDArray computeLossFunctionScoreArray(org.nd4j.linalg.api.ndarray.INDArray data,
org.nd4j.linalg.api.ndarray.INDArray reconstruction)
public boolean hasLossFunction()
Specified by: hasLossFunction in interface ReconstructionDistribution

public int distributionInputSize(int dataSize)
Specified by: distributionInputSize in interface ReconstructionDistribution
Parameters:
dataSize - Size of the data, i.e., the nIn value

public double negLogProbability(org.nd4j.linalg.api.ndarray.INDArray x,
org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams,
boolean average)
Specified by: negLogProbability in interface ReconstructionDistribution
Parameters:
x - Data to be modelled (reconstructions)
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian)
average - Whether the log probability should be averaged over the minibatch, or simply summed

public org.nd4j.linalg.api.ndarray.INDArray exampleNegLogProbability(org.nd4j.linalg.api.ndarray.INDArray x,
org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: exampleNegLogProbability in interface ReconstructionDistribution
Parameters:
x - Data to be modelled (reconstructions)
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

public org.nd4j.linalg.api.ndarray.INDArray gradient(org.nd4j.linalg.api.ndarray.INDArray x,
org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: gradient in interface ReconstructionDistribution
Parameters:
x - Data
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

public org.nd4j.linalg.api.ndarray.INDArray generateRandom(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: generateRandom in interface ReconstructionDistribution
Parameters:
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

public org.nd4j.linalg.api.ndarray.INDArray generateAtMean(org.nd4j.linalg.api.ndarray.INDArray preOutDistributionParams)
Specified by: generateAtMean in interface ReconstructionDistribution
Parameters:
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function
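The methods above separate three things. `exampleNegLogProbability` gives a per-example negative log probability. `negLogProbability` gives a minibatch total, either summed or averaged. `generateAtMean` returns the distribution's mean.

The sketch below illustrates all three for a composite with one Gaussian block and one Bernoulli block. It uses plain `double` arrays instead of INDArray. It makes several assumptions:

- The composite's per-example value is the sum of the sub-distributions' values.
- The Gaussian parameters are laid out as all means followed by all log variances, with an identity activation.
- Each Bernoulli parameter goes through a sigmoid.

`CompositeNll` and its method signatures are hypothetical and are not the DL4J API.

```java
final class CompositeNll {
    private static final double LOG_2PI = Math.log(2.0 * Math.PI);

    private static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Gaussian negative log density of one value (identity activation assumed).
    static double gaussianNll(double x, double mean, double logVar) {
        double diff = x - mean;
        return 0.5 * (LOG_2PI + logVar + diff * diff / Math.exp(logVar));
    }

    // Bernoulli negative log probability of one value (sigmoid activation assumed).
    static double bernoulliNll(double x, double preOut) {
        double p = sigmoid(preOut);
        return -(x * Math.log(p) + (1.0 - x) * Math.log(1.0 - p));
    }

    // Assumed parameter row layout: [means (nG), logVars (nG), Bernoulli preOuts (nB)].
    // Assumed data row layout: [Gaussian columns (nG), Bernoulli columns (nB)].
    static double exampleNegLogProbability(double[] x, double[] params, int nG, int nB) {
        double nll = 0.0;
        for (int j = 0; j < nG; j++) {
            nll += gaussianNll(x[j], params[j], params[nG + j]);
        }
        for (int j = 0; j < nB; j++) {
            nll += bernoulliNll(x[nG + j], params[2 * nG + j]);
        }
        return nll;
    }

    // Summed over the minibatch, or averaged over it when average is true.
    static double negLogProbability(double[][] x, double[][] params, int nG, int nB, boolean average) {
        double total = 0.0;
        for (int i = 0; i < x.length; i++) {
            total += exampleNegLogProbability(x[i], params[i], nG, nB);
        }
        return average ? total / x.length : total;
    }

    // Mean of each block: the Gaussian mean and the Bernoulli probability.
    static double[] generateAtMean(double[] params, int nG, int nB) {
        double[] out = new double[nG + nB];
        for (int j = 0; j < nG; j++) {
            out[j] = params[j];
        }
        for (int j = 0; j < nB; j++) {
            out[nG + j] = sigmoid(params[2 * nG + j]);
        }
        return out;
    }
}
```

Take a Gaussian mean of 0, a log variance of 0, and a Bernoulli pre-activation of 0. The example `{0.0, 1.0}` then scores 0.5 * ln(2π) + ln 2, about 1.612. A production version would compute the Bernoulli term directly from the pre-activation in a numerically stable form; the sketch keeps the textbook formula for readability.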