public class TransferLearningHelper
extends java.lang.Object
Constructor and Description |
---|
TransferLearningHelper(ComputationGraph orig)
Expects a computation graph where some vertices are frozen
|
TransferLearningHelper(ComputationGraph orig,
java.lang.String... frozenOutputAt)
Modifies the given computation graph (in place!) to freeze the vertices from the input up to the specified vertex.
|
TransferLearningHelper(MultiLayerNetwork orig)
Expects an MLN where some layers are frozen
|
TransferLearningHelper(MultiLayerNetwork orig,
int frozenTill)
Modifies the given MLN (in place!) to freeze the specified layer and all layers below it (their params are held constant during training)
|
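The "specified and below" wording of the `frozenTill` constructor can be read as an inclusive index range. The following is a minimal plain-Java sketch of that reading, not DL4J code: the class `FrozenTillSketch`, its methods, and the range check are hypothetical and exist only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Illustration only: plain Java, not the DL4J API.
class FrozenTillSketch {

    // Layers 0..frozenTill (inclusive) are treated as frozen.
    static List<Integer> frozenIndices(int numLayers, int frozenTill) {
        checkRange(numLayers, frozenTill);
        List<Integer> frozen = new ArrayList<>();
        for (int i = 0; i <= frozenTill; i++) {
            frozen.add(i);
        }
        return frozen;
    }

    // Layers frozenTill+1..numLayers-1 stay trainable.
    static List<Integer> unfrozenIndices(int numLayers, int frozenTill) {
        checkRange(numLayers, frozenTill);
        List<Integer> unfrozen = new ArrayList<>();
        for (int i = frozenTill + 1; i < numLayers; i++) {
            unfrozen.add(i);
        }
        return unfrozen;
    }

    // Assumption of this sketch: an out-of-range index is rejected.
    private static void checkRange(int numLayers, int frozenTill) {
        if (frozenTill < 0 || frozenTill >= numLayers) {
            throw new IllegalArgumentException("frozenTill out of range: " + frozenTill);
        }
    }

    public static void main(String[] args) {
        System.out.println(frozenIndices(4, 1));   // prints [0, 1]
        System.out.println(unfrozenIndices(4, 1)); // prints [2, 3]
    }
}
```

With four layers and `frozenTill = 1`, layers 0 and 1 are held constant and layers 2 and 3 remain trainable.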
Modifier and Type | Method and Description |
---|---|
void |
errorIfGraphIfMLN() |
org.nd4j.linalg.dataset.DataSet |
featurize(org.nd4j.linalg.dataset.DataSet input)
During training, frozen vertices/layers can be treated as "featurizing" the input.
The forward pass through these frozen layers/vertices can be done in advance and the resulting dataset saved to disk, so that
iterating on the smaller unfrozen part of the model is quick.
Currently does not support datasets with feature masks.
|
org.nd4j.linalg.dataset.MultiDataSet |
featurize(org.nd4j.linalg.dataset.MultiDataSet input)
During training, frozen vertices/layers can be treated as "featurizing" the input.
The forward pass through these frozen layers/vertices can be done in advance and the resulting dataset saved to disk, so that
iterating on the smaller unfrozen part of the model is quick.
Currently does not support datasets with feature masks.
|
void |
fitFeaturized(org.nd4j.linalg.dataset.DataSet input) |
void |
fitFeaturized(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter) |
void |
fitFeaturized(org.nd4j.linalg.dataset.MultiDataSet input) |
void |
fitFeaturized(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iter)
Fit from a featurized dataset.
|
org.nd4j.linalg.api.ndarray.INDArray |
outputFromFeaturized(org.nd4j.linalg.api.ndarray.INDArray input)
Use to get the output from a featurized input
|
org.nd4j.linalg.api.ndarray.INDArray[] |
outputFromFeaturized(org.nd4j.linalg.api.ndarray.INDArray[] input)
Use to get the output from a featurized input
|
ComputationGraph |
unfrozenGraph()
Returns the unfrozen subset of the original computation graph as a computation graph.
Note that with each call to fitFeaturized the parameters of the original computation graph are also updated.
|
MultiLayerNetwork |
unfrozenMLN()
Returns the unfrozen layers of the MultiLayerNetwork as a MultiLayerNetwork.
Note that with each call to fitFeaturized the parameters of the original MLN are also updated.
|
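The featurize-then-fit workflow described in the method summary can be illustrated without the library. The following is a minimal plain-Java sketch of the idea, not the DL4J implementation: the class `FeaturizeSketch`, its fixed frozen weights, the single-unit trainable head, and the SGD update are all assumptions chosen to keep the example small. Only the method names mirror the ones listed in the summary.

```java
// Illustration only: a hypothetical two-part model in plain Java, not the DL4J API.
class FeaturizeSketch {

    // Frozen part: a fixed linear map followed by ReLU. These weights are never updated.
    static final double[][] FROZEN_W = {{1.0, -1.0}, {0.5, 2.0}};

    // Unfrozen part: one linear output unit whose weights are trained.
    double[] headW = {0.0, 0.0};

    // Forward pass through the frozen part only ("featurizing" the input).
    static double[] featurize(double[] x) {
        double[] f = new double[FROZEN_W.length];
        for (int i = 0; i < FROZEN_W.length; i++) {
            double s = 0.0;
            for (int j = 0; j < x.length; j++) {
                s += FROZEN_W[i][j] * x[j];
            }
            f[i] = Math.max(0.0, s);
        }
        return f;
    }

    // Forward pass through the unfrozen part, starting from precomputed features.
    double outputFromFeaturized(double[] f) {
        double y = 0.0;
        for (int i = 0; i < f.length; i++) {
            y += headW[i] * f[i];
        }
        return y;
    }

    // Full forward pass; it shares headW, so it sees every update made by fitFeaturized.
    double output(double[] x) {
        return outputFromFeaturized(featurize(x));
    }

    // One SGD step on squared error, touching only the unfrozen weights.
    void fitFeaturized(double[] f, double label, double learningRate) {
        double err = outputFromFeaturized(f) - label;
        for (int i = 0; i < f.length; i++) {
            headW[i] -= learningRate * err * f[i];
        }
    }

    public static void main(String[] args) {
        double[][] inputs = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}};
        double[] labels = {1.0, 2.0, 3.0};

        // Featurize once, up front; these arrays could be cached or written to disk.
        double[][] feats = new double[inputs.length][];
        for (int i = 0; i < inputs.length; i++) {
            feats[i] = featurize(inputs[i]);
        }

        // Train repeatedly on the cached features without re-running the frozen part.
        FeaturizeSketch model = new FeaturizeSketch();
        for (int epoch = 0; epoch < 200; epoch++) {
            for (int i = 0; i < feats.length; i++) {
                model.fitFeaturized(feats[i], labels[i], 0.05);
            }
        }

        for (int i = 0; i < inputs.length; i++) {
            System.out.println("full=" + model.output(inputs[i])
                    + " fromFeaturized=" + model.outputFromFeaturized(feats[i]));
        }
    }
}
```

In this sketch the frozen pass runs once per example rather than once per epoch, which is the saving the method summary describes, and the full model and the featurized path return the same value because they share the trained weights.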
public TransferLearningHelper(ComputationGraph orig, java.lang.String... frozenOutputAt)
orig - computation graph
frozenOutputAt - vertex to freeze at (hold params constant during training)
public TransferLearningHelper(ComputationGraph orig)
orig - computation graph in which some vertices are already frozen
public TransferLearningHelper(MultiLayerNetwork orig, int frozenTill)
orig - MLN to freeze
frozenTill - index of the layer to freeze, together with all layers below it
public TransferLearningHelper(MultiLayerNetwork orig)
orig - MLN in which some layers are already frozen
public void errorIfGraphIfMLN()
public ComputationGraph unfrozenGraph()
public MultiLayerNetwork unfrozenMLN()
public org.nd4j.linalg.api.ndarray.INDArray[] outputFromFeaturized(org.nd4j.linalg.api.ndarray.INDArray[] input)
input - featurized data
public org.nd4j.linalg.api.ndarray.INDArray outputFromFeaturized(org.nd4j.linalg.api.ndarray.INDArray input)
input - featurized data
public org.nd4j.linalg.dataset.MultiDataSet featurize(org.nd4j.linalg.dataset.MultiDataSet input)
input - multidataset to feed into the computation graph with frozen layer vertices
public org.nd4j.linalg.dataset.DataSet featurize(org.nd4j.linalg.dataset.DataSet input)
input - dataset to feed into the model with frozen layers/vertices
public void fitFeaturized(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iter)
iter - iterator over featurized multidatasets
public void fitFeaturized(org.nd4j.linalg.dataset.MultiDataSet input)
public void fitFeaturized(org.nd4j.linalg.dataset.DataSet input)
public void fitFeaturized(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter)