public class SparkCBOW extends BaseSparkLearningAlgorithm
Modifier and Type | Field and Description
---|---
`protected java.lang.ThreadLocal<org.nd4j.parameterserver.distributed.messages.Frame<org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage>>` | `frame`

Fields inherited from class `BaseSparkLearningAlgorithm`: `nextRandom`, `vectorsConfiguration`, `vocabCache`
| Constructor and Description |
|---|
| `SparkCBOW()` |
Modifier and Type | Method and Description
---|---
`org.nd4j.parameterserver.distributed.messages.Frame<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage>` | `frameSequence(Sequence<ShallowSequenceElement> sequence, java.util.concurrent.atomic.AtomicLong nextRandom, double learningRate)`
`java.lang.String` | `getCodeName()`
`org.nd4j.parameterserver.distributed.training.TrainingDriver<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage>` | `getTrainingDriver()`
`protected void` | `iterateSample(ShallowSequenceElement currentWord, int[] windowWords, java.util.concurrent.atomic.AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, org.nd4j.linalg.api.ndarray.INDArray inferenceVector)`
Methods inherited from class `BaseSparkLearningAlgorithm`: `applySubsampling`, `configure`, `finish`, `isEarlyTerminationHit`, `learnSequence`, `pretrain`
protected transient java.lang.ThreadLocal<org.nd4j.parameterserver.distributed.messages.Frame<org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage>> frame
public java.lang.String getCodeName()
public org.nd4j.parameterserver.distributed.messages.Frame<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence, java.util.concurrent.atomic.AtomicLong nextRandom, double learningRate)
protected void iterateSample(ShallowSequenceElement currentWord, int[] windowWords, java.util.concurrent.atomic.AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, org.nd4j.linalg.api.ndarray.INDArray inferenceVector)
public org.nd4j.parameterserver.distributed.training.TrainingDriver<? extends org.nd4j.parameterserver.distributed.messages.TrainingMessage> getTrainingDriver()