public class SparkSkipGram extends BaseSparkLearningAlgorithm
Modifier and Type | Field and Description |
---|---|
protected java.util.concurrent.atomic.AtomicLong | counter |
protected org.nd4j.parameterserver.distributed.training.TrainingDriver<org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage> | driver |
protected java.lang.ThreadLocal<org.nd4j.parameterserver.distributed.messages.Frame<org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage>> | frame |
Inherited fields: nextRandom, vectorsConfiguration, vocabCache
Constructor and Description |
---|
SparkSkipGram() |
Modifier and Type | Method and Description |
---|---|
org.nd4j.parameterserver.distributed.messages.Frame<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage> | frameSequence(Sequence<ShallowSequenceElement> sequence, java.util.concurrent.atomic.AtomicLong nextRandom, double learningRate) |
java.lang.String | getCodeName() |
org.nd4j.parameterserver.distributed.training.TrainingDriver<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage> | getTrainingDriver() |
protected void | iterateSample(ShallowSequenceElement word, ShallowSequenceElement lastWord, java.util.concurrent.atomic.AtomicLong nextRandom, double lr) |
Inherited methods: applySubsampling, configure, finish, isEarlyTerminationHit, learnSequence, pretrain
protected transient java.util.concurrent.atomic.AtomicLong counter
protected transient java.lang.ThreadLocal<org.nd4j.parameterserver.distributed.messages.Frame<org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage>> frame
protected org.nd4j.parameterserver.distributed.training.TrainingDriver<org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage> driver
public java.lang.String getCodeName()
public org.nd4j.parameterserver.distributed.messages.Frame<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, java.util.concurrent.atomic.AtomicLong nextRandom, double learningRate)
protected void iterateSample(ShallowSequenceElement word, ShallowSequenceElement lastWord, java.util.concurrent.atomic.AtomicLong nextRandom, double lr)
public org.nd4j.parameterserver.distributed.training.TrainingDriver<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage> getTrainingDriver()