public class DM<T extends SequenceElement> extends java.lang.Object implements SequenceLearningAlgorithm<T>
Modifier and Type | Field and Description |
---|---|
protected double[] |
expTable |
protected static double |
MAX_EXP |
protected double |
negative |
protected double |
sampling |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn0 |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn1 |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn1Neg |
protected org.nd4j.linalg.api.ndarray.INDArray |
table |
protected boolean |
useAdaGrad |
protected int |
window |
Constructor and Description |
---|
DM() |
Modifier and Type | Method and Description |
---|---|
void |
configure(VocabCache<T> vocabCache,
WeightLookupTable<T> lookupTable,
VectorsConfiguration configuration) |
void |
dm(int i,
Sequence<T> sequence,
int b,
java.util.concurrent.atomic.AtomicLong nextRandom,
double alpha,
java.util.List<T> labels,
boolean isInference,
org.nd4j.linalg.api.ndarray.INDArray inferenceVector) |
void |
finish() |
java.lang.String |
getCodeName() |
ElementsLearningAlgorithm<T> |
getElementsLearningAlgorithm() |
org.nd4j.linalg.api.ndarray.INDArray |
inferSequence(Sequence<T> sequence,
long nr,
double learningRate,
double minLearningRate,
int iterations)
This method trains on a previously unseen paragraph and returns the inferred vector.
|
boolean |
isEarlyTerminationHit() |
double |
learnSequence(Sequence<T> sequence,
java.util.concurrent.atomic.AtomicLong nextRandom,
double learningRate)
This method trains over the sequence of elements passed into it.
|
void |
pretrain(SequenceIterator<T> iterator) |
protected static double MAX_EXP
protected int window
protected boolean useAdaGrad
protected double negative
protected double sampling
protected double[] expTable
protected org.nd4j.linalg.api.ndarray.INDArray syn0
protected org.nd4j.linalg.api.ndarray.INDArray syn1
protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
protected org.nd4j.linalg.api.ndarray.INDArray table
public ElementsLearningAlgorithm<T> getElementsLearningAlgorithm()
Specified by:
getElementsLearningAlgorithm in interface SequenceLearningAlgorithm<T extends SequenceElement>
public java.lang.String getCodeName()
Specified by:
getCodeName in interface SequenceLearningAlgorithm<T extends SequenceElement>
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
Specified by:
configure in interface SequenceLearningAlgorithm<T extends SequenceElement>
public void pretrain(SequenceIterator<T> iterator)
Specified by:
pretrain in interface SequenceLearningAlgorithm<T extends SequenceElement>
public double learnSequence(Sequence<T> sequence, java.util.concurrent.atomic.AtomicLong nextRandom, double learningRate)
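Description copied from interface: SequenceLearningAlgorithm
This method trains over the sequence of elements passed into it.
Specified by:
learnSequence in interface SequenceLearningAlgorithm<T extends SequenceElement>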
public void dm(int i, Sequence<T> sequence, int b, java.util.concurrent.atomic.AtomicLong nextRandom, double alpha, java.util.List<T> labels, boolean isInference, org.nd4j.linalg.api.ndarray.INDArray inferenceVector)
public boolean isEarlyTerminationHit()
Specified by:
isEarlyTerminationHit in interface SequenceLearningAlgorithm<T extends SequenceElement>
public org.nd4j.linalg.api.ndarray.INDArray inferSequence(Sequence<T> sequence, long nr, double learningRate, double minLearningRate, int iterations)
This method trains on a previously unseen paragraph and returns the inferred vector.
Specified by:
inferSequence in interface SequenceLearningAlgorithm<T extends SequenceElement>
Parameters:
sequence
nr
learningRate
public void finish()
Specified by:
finish in interface SequenceLearningAlgorithm<T extends SequenceElement>