public class InMemoryLookupTable<T extends SequenceElement> extends java.lang.Object implements WeightLookupTable<T>
Modifier and Type | Class and Description |
---|---|
static class |
InMemoryLookupTable.Builder<T extends SequenceElement> |
protected class |
InMemoryLookupTable.WeightIterator |
Modifier and Type | Field and Description |
---|---|
protected org.nd4j.linalg.learning.AdaGrad |
adaGrad |
protected java.util.Map<java.lang.Integer,org.nd4j.linalg.api.ndarray.INDArray> |
codes |
protected double[] |
expTable |
protected com.google.common.util.concurrent.AtomicDouble |
lr |
protected static double |
MAX_EXP |
protected double |
negative |
protected org.nd4j.linalg.api.rng.Random |
rng |
protected long |
seed |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn0 |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn1 |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn1Neg |
protected org.nd4j.linalg.api.ndarray.INDArray |
table |
protected java.lang.Long |
tableId |
protected boolean |
useAdaGrad |
protected boolean |
useHS |
protected int |
vectorLength |
protected VocabCache<T> |
vocab |
Constructor and Description |
---|
InMemoryLookupTable() |
InMemoryLookupTable(VocabCache<T> vocab,
int vectorLength,
boolean useAdaGrad,
double lr,
org.nd4j.linalg.api.rng.Random gen,
double negative) |
InMemoryLookupTable(VocabCache<T> vocab,
int vectorLength,
boolean useAdaGrad,
double lr,
org.nd4j.linalg.api.rng.Random gen,
double negative,
boolean useHS) |
Modifier and Type | Method and Description |
---|---|
void |
consume(InMemoryLookupTable<T> srcTable)
This method consumes weights of a given InMemoryLookupTable
PLEASE NOTE: this method explicitly resets current weights
|
java.util.Map<java.lang.Integer,org.nd4j.linalg.api.ndarray.INDArray> |
getCodes() |
double[] |
getExpTable() |
double |
getGradient(int column,
double gradient)
Returns gradient for specified word
|
com.google.common.util.concurrent.AtomicDouble |
getLr()
Deprecated.
|
double |
getNegative() |
org.nd4j.linalg.api.ndarray.INDArray |
getSyn0() |
org.nd4j.linalg.api.ndarray.INDArray |
getSyn1() |
org.nd4j.linalg.api.ndarray.INDArray |
getSyn1Neg() |
org.nd4j.linalg.api.ndarray.INDArray |
getTable() |
VocabCache |
getVocab() |
VocabCache<T> |
getVocabCache()
Returns corresponding vocabulary
|
org.nd4j.linalg.api.ndarray.INDArray |
getWeights() |
protected void |
initAdaGrad() |
protected void |
initExpTable() |
void |
initNegative() |
boolean |
isUseAdaGrad() |
void |
iterate(T w1,
T w2)
Deprecated.
|
void |
iterateSample(T w1,
T w2,
java.util.concurrent.atomic.AtomicLong nextRandom,
double alpha)
Deprecated.
|
int |
layerSize()
The layer size for the lookup table
|
org.nd4j.linalg.api.ndarray.INDArray |
loadCodes(int[] codes)
Loads the co-occurrences for the given codes
|
protected void |
makeTable(int tableSize,
double power) |
void |
plotVocab(BarnesHutTsne tsne,
int numWords,
java.io.File file)
Render the words via TSNE
|
void |
plotVocab(BarnesHutTsne tsne,
int numWords,
UiConnectionInfo connectionInfo)
Render the words via TSNE
|
void |
plotVocab(int numWords,
java.io.File file)
Render the words via tsne
|
void |
plotVocab(int numWords,
UiConnectionInfo connectionInfo)
Render the words via tsne
|
void |
putCode(int codeIndex,
org.nd4j.linalg.api.ndarray.INDArray code) |
void |
putVector(java.lang.String word,
org.nd4j.linalg.api.ndarray.INDArray vector)
Inserts a word vector
|
void |
resetWeights()
Reset the weights of the cache
|
void |
resetWeights(boolean reset)
Clear out all weights regardless
|
void |
setCodes(java.util.Map<java.lang.Integer,org.nd4j.linalg.api.ndarray.INDArray> codes) |
void |
setExpTable(double[] expTable) |
void |
setLearningRate(double lr)
Sets the learning rate
|
void |
setLr(com.google.common.util.concurrent.AtomicDouble lr) |
void |
setNegative(double negative) |
void |
setSyn0(org.nd4j.linalg.api.ndarray.INDArray syn0) |
void |
setSyn1(org.nd4j.linalg.api.ndarray.INDArray syn1) |
void |
setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg) |
void |
setTable(org.nd4j.linalg.api.ndarray.INDArray table) |
void |
setUseAdaGrad(boolean useAdaGrad) |
void |
setUseHS(boolean useHS) |
void |
setVectorLength(int vectorLength) |
void |
setVocab(VocabCache vocab) |
java.lang.String |
toString() |
org.nd4j.linalg.api.ndarray.INDArray |
vector(java.lang.String word) |
java.util.Iterator<org.nd4j.linalg.api.ndarray.INDArray> |
vectors()
Iterates through all of the vectors in the cache
|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
getTableId, setTableId
protected org.nd4j.linalg.api.ndarray.INDArray syn0
protected org.nd4j.linalg.api.ndarray.INDArray syn1
protected int vectorLength
protected transient org.nd4j.linalg.api.rng.Random rng
protected com.google.common.util.concurrent.AtomicDouble lr
protected double[] expTable
protected static double MAX_EXP
protected long seed
protected org.nd4j.linalg.api.ndarray.INDArray table
protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
protected boolean useAdaGrad
protected double negative
protected boolean useHS
protected VocabCache<T extends SequenceElement> vocab
protected java.util.Map<java.lang.Integer,org.nd4j.linalg.api.ndarray.INDArray> codes
protected org.nd4j.linalg.learning.AdaGrad adaGrad
protected java.lang.Long tableId
public InMemoryLookupTable()
public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative, boolean useHS)
public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative)
protected void initAdaGrad()
public double[] getExpTable()
public void setExpTable(double[] expTable)
public double getGradient(int column, double gradient)
WeightLookupTable
getGradient
in interface WeightLookupTable<T extends SequenceElement>
public int layerSize()
WeightLookupTable
layerSize
in interface WeightLookupTable<T extends SequenceElement>
public void resetWeights(boolean reset)
WeightLookupTable
resetWeights
in interface WeightLookupTable<T extends SequenceElement>
public void plotVocab(BarnesHutTsne tsne, int numWords, java.io.File file)
WeightLookupTable
plotVocab
in interface WeightLookupTable<T extends SequenceElement>
tsne
- the tsne to use
public void plotVocab(int numWords, java.io.File file)
plotVocab
in interface WeightLookupTable<T extends SequenceElement>
public void plotVocab(int numWords, UiConnectionInfo connectionInfo)
plotVocab
in interface WeightLookupTable<T extends SequenceElement>
public void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo)
plotVocab
in interface WeightLookupTable<T extends SequenceElement>
tsne
- the tsne to use
numWords
-
connectionInfo
-
public void putCode(int codeIndex, org.nd4j.linalg.api.ndarray.INDArray code)
putCode
in interface WeightLookupTable<T extends SequenceElement>
codeIndex
-
code
-
public org.nd4j.linalg.api.ndarray.INDArray loadCodes(int[] codes)
loadCodes
in interface WeightLookupTable<T extends SequenceElement>
codes
- the codes to load
public void initNegative()
protected void initExpTable()
@Deprecated public void iterateSample(T w1, T w2, java.util.concurrent.atomic.AtomicLong nextRandom, double alpha)
iterateSample
in interface WeightLookupTable<T extends SequenceElement>
w1
- the first word to iterate on
w2
- the second word to iterate on
nextRandom
- next random for sampling
alpha
- the alpha to use for learning
public boolean isUseAdaGrad()
public void setUseAdaGrad(boolean useAdaGrad)
public double getNegative()
public void setUseHS(boolean useHS)
public void setNegative(double negative)
@Deprecated public void iterate(T w1, T w2)
iterate
in interface WeightLookupTable<T extends SequenceElement>
w1
- the first word to iterate on
w2
- the second word to iterate on
public void resetWeights()
resetWeights
in interface WeightLookupTable<T extends SequenceElement>
protected void makeTable(int tableSize, double power)
public void putVector(java.lang.String word, org.nd4j.linalg.api.ndarray.INDArray vector)
putVector
in interface WeightLookupTable<T extends SequenceElement>
word
- the word to insert
vector
- the vector to insert
public org.nd4j.linalg.api.ndarray.INDArray getTable()
public void setTable(org.nd4j.linalg.api.ndarray.INDArray table)
public org.nd4j.linalg.api.ndarray.INDArray getSyn1Neg()
public void setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg)
public org.nd4j.linalg.api.ndarray.INDArray vector(java.lang.String word)
vector
in interface WeightLookupTable<T extends SequenceElement>
word
-
public void setLearningRate(double lr)
WeightLookupTable
setLearningRate
in interface WeightLookupTable<T extends SequenceElement>
public java.util.Iterator<org.nd4j.linalg.api.ndarray.INDArray> vectors()
WeightLookupTable
vectors
in interface WeightLookupTable<T extends SequenceElement>
public org.nd4j.linalg.api.ndarray.INDArray getWeights()
getWeights
in interface WeightLookupTable<T extends SequenceElement>
public org.nd4j.linalg.api.ndarray.INDArray getSyn0()
public void setSyn0(org.nd4j.linalg.api.ndarray.INDArray syn0)
public org.nd4j.linalg.api.ndarray.INDArray getSyn1()
public void setSyn1(org.nd4j.linalg.api.ndarray.INDArray syn1)
public VocabCache<T> getVocabCache()
WeightLookupTable
getVocabCache
in interface WeightLookupTable<T extends SequenceElement>
public void setVectorLength(int vectorLength)
@Deprecated public com.google.common.util.concurrent.AtomicDouble getLr()
public void setLr(com.google.common.util.concurrent.AtomicDouble lr)
public VocabCache getVocab()
public void setVocab(VocabCache vocab)
public java.util.Map<java.lang.Integer,org.nd4j.linalg.api.ndarray.INDArray> getCodes()
public void setCodes(java.util.Map<java.lang.Integer,org.nd4j.linalg.api.ndarray.INDArray> codes)
public java.lang.String toString()
toString
in class java.lang.Object
public void consume(InMemoryLookupTable<T> srcTable)
srcTable
-