public class InMemoryGraphLookupTable extends java.lang.Object implements GraphVectorLookupTable
Modifier and Type | Field and Description |
---|---|
protected double[] | expTable |
protected double | learningRate |
protected static double | MAX_EXP |
protected int | nVertices |
protected org.nd4j.linalg.api.ndarray.INDArray | outWeights |
protected BinaryTree | tree |
protected int | vectorSize |
protected org.nd4j.linalg.api.ndarray.INDArray | vertexVectors |

Constructor and Description |
---|
InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate) |

Modifier and Type | Method and Description |
---|---|
double | calculateProb(int first, int second)<br>Calculate the probability of the second vertex given the first vertex, i.e., $P(v_{\text{second}} \mid v_{\text{first}})$. |
double | calculateScore(int first, int second)<br>Calculate score. |
org.nd4j.linalg.api.ndarray.INDArray | getInnerNodeVector(int innerNode) |
int | getNumVertices()<br>Returns the number of vertices in the graph. |
org.nd4j.linalg.api.ndarray.INDArray | getOutWeights() |
BinaryTree | getTree() |
org.nd4j.linalg.api.ndarray.INDArray | getVector(int idx)<br>Get the vector for the vertex with index idx. |
org.nd4j.linalg.api.ndarray.INDArray | getVertexVectors() |
void | iterate(int first, int second)<br>Conduct learning given a pair of vertices (in and out). |
void | resetWeights()<br>Reset (randomize) the weights. |
void | setLearningRate(double learningRate)<br>Set the learning rate. |
void | setVertexVectors(org.nd4j.linalg.api.ndarray.INDArray vertexVectors) |
org.nd4j.linalg.api.ndarray.INDArray[][] | vectorsAndGradients(int first, int second)<br>Returns the vertex vector and its gradient, plus the inner node vectors and their gradients.<br>Specifically, out[0] holds the vectors and out[1] holds the gradients for the corresponding vectors: out[0][0] is the vector for the first vertex and out[1][0] is the gradient for this vertex vector; out[0][i] (i > 0) is the inner node vector along the path to the second vertex, and out[1][i] is the gradient for that inner node vector. This design is used primarily to aid in testing (numerical gradient checks). |
int | vectorSize()<br>The size of the vector representations. |

protected int nVertices
protected int vectorSize
protected BinaryTree tree
protected org.nd4j.linalg.api.ndarray.INDArray vertexVectors
protected org.nd4j.linalg.api.ndarray.INDArray outWeights
protected double learningRate
protected double[] expTable
protected static double MAX_EXP
public InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate)
public org.nd4j.linalg.api.ndarray.INDArray getVertexVectors()
public org.nd4j.linalg.api.ndarray.INDArray getOutWeights()
public int vectorSize()
GraphVectorLookupTable
vectorSize
in interface GraphVectorLookupTable
public void resetWeights()
GraphVectorLookupTable
resetWeights
in interface GraphVectorLookupTable
public void iterate(int first, int second)
GraphVectorLookupTable
iterate
in interface GraphVectorLookupTable
public org.nd4j.linalg.api.ndarray.INDArray[][] vectorsAndGradients(int first, int second)
first
- first (input) vertex index
second
- second (output) vertex index
public double calculateProb(int first, int second)
first
- index of the first vertex
second
- index of the second vertex
public double calculateScore(int first, int second)
public BinaryTree getTree()
public org.nd4j.linalg.api.ndarray.INDArray getInnerNodeVector(int innerNode)
public org.nd4j.linalg.api.ndarray.INDArray getVector(int idx)
GraphVectorLookupTable
getVector
in interface GraphVectorLookupTable
public void setLearningRate(double learningRate)
GraphVectorLookupTable
setLearningRate
in interface GraphVectorLookupTable
public int getNumVertices()
GraphVectorLookupTable
getNumVertices
in interface GraphVectorLookupTable
public void setVertexVectors(org.nd4j.linalg.api.ndarray.INDArray vertexVectors)