public class KerasModel
extends java.lang.Object
Modifier and Type | Field and Description |
---|---|
protected java.lang.String |
className |
protected boolean |
enforceTrainingConfig |
static java.lang.String |
HDF5_MODEL_CONFIG_ATTRIBUTE |
static java.lang.String |
HDF5_MODEL_WEIGHTS_ROOT |
static java.lang.String |
HDF5_TRAINING_CONFIG_ATTRIBUTE |
protected java.util.ArrayList<java.lang.String> |
inputLayerNames |
protected java.util.Map<java.lang.String,KerasLayer> |
layers |
protected java.util.List<KerasLayer> |
layersOrdered |
static java.lang.String |
MODEL_CLASS_NAME_MODEL |
static java.lang.String |
MODEL_CLASS_NAME_SEQUENTIAL |
static java.lang.String |
MODEL_CONFIG_FIELD_INPUT_LAYERS |
static java.lang.String |
MODEL_CONFIG_FIELD_LAYERS |
static java.lang.String |
MODEL_CONFIG_FIELD_OUTPUT_LAYERS |
static java.lang.String |
MODEL_FIELD_CLASS_NAME |
static java.lang.String |
MODEL_FIELD_CONFIG |
protected java.util.ArrayList<java.lang.String> |
outputLayerNames |
protected java.util.Map<java.lang.String,InputType> |
outputTypes |
static java.lang.String |
TRAINING_CONFIG_FIELD_LOSS |
protected int |
truncatedBPTT |
protected boolean |
useTruncatedBPTT |
Modifier | Constructor and Description |
---|---|
protected |
KerasModel() |
|
KerasModel(org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder modelBuilder)
(Recommended) Builder-pattern constructor for (Functional API) Model.
|
protected |
KerasModel(java.lang.String modelJson,
java.lang.String modelYaml,
Hdf5Archive weightsArchive,
java.lang.String weightsRoot,
java.lang.String trainingJson,
Hdf5Archive trainingArchive,
boolean enforceTrainingConfig)
(Not recommended) Constructor for (Functional API) Model from model configuration
(JSON or YAML), training configuration (JSON), weights, and "training mode"
boolean indicator.
|
Modifier and Type | Method and Description |
---|---|
ComputationGraph |
getComputationGraph()
Build a ComputationGraph from this Keras Model configuration and import weights.
|
ComputationGraph |
getComputationGraph(boolean importWeights)
Build a ComputationGraph from this Keras Model configuration and (optionally) import weights.
|
ComputationGraphConfiguration |
getComputationGraphConfiguration()
Configure a ComputationGraph from this Keras Model configuration.
|
protected Model |
helperCopyWeightsToModel(Model model)
Helper function to import weights from nested Map into existing model.
|
protected void |
helperImportTrainingConfiguration(java.lang.String trainingConfigJson)
Helper method called from constructor.
|
protected void |
helperImportWeights(Hdf5Archive weightsArchive,
java.lang.String weightsRoot)
Store weights to import with each associated Keras layer.
|
protected void |
helperInferOutputTypes()
Helper method called from constructor.
|
protected void |
helperPrepareLayers(java.util.List<java.lang.Object> layerConfigs)
Helper method called from constructor.
|
static java.util.Map<java.lang.String,java.lang.Object> |
parseJsonString(java.lang.String json)
Convenience function for parsing JSON strings.
|
static java.util.Map<java.lang.String,java.lang.Object> |
parseYamlString(java.lang.String json)
Convenience function for parsing YAML strings.
|
Field Detail
public static final java.lang.String MODEL_FIELD_CLASS_NAME
public static final java.lang.String MODEL_CLASS_NAME_SEQUENTIAL
public static final java.lang.String MODEL_CLASS_NAME_MODEL
public static final java.lang.String MODEL_FIELD_CONFIG
public static final java.lang.String MODEL_CONFIG_FIELD_LAYERS
public static final java.lang.String MODEL_CONFIG_FIELD_INPUT_LAYERS
public static final java.lang.String MODEL_CONFIG_FIELD_OUTPUT_LAYERS
public static final java.lang.String TRAINING_CONFIG_FIELD_LOSS
public static final java.lang.String HDF5_MODEL_WEIGHTS_ROOT
public static final java.lang.String HDF5_MODEL_CONFIG_ATTRIBUTE
public static final java.lang.String HDF5_TRAINING_CONFIG_ATTRIBUTE
protected java.lang.String className
protected boolean enforceTrainingConfig
protected java.util.List<KerasLayer> layersOrdered
protected java.util.Map<java.lang.String,KerasLayer> layers
protected java.util.Map<java.lang.String,InputType> outputTypes
protected java.util.ArrayList<java.lang.String> inputLayerNames
protected java.util.ArrayList<java.lang.String> outputLayerNames
protected boolean useTruncatedBPTT
protected int truncatedBPTT
Constructor Detail
public KerasModel(org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, java.io.IOException, InvalidKerasConfigurationException
Parameters:
modelBuilder - builder object
Throws:
java.io.IOException
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected KerasModel(java.lang.String modelJson, java.lang.String modelYaml, Hdf5Archive weightsArchive, java.lang.String weightsRoot, java.lang.String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig) throws java.io.IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
modelJson - model configuration JSON string
modelYaml - model configuration YAML string
enforceTrainingConfig - whether to enforce training-related configurations
Throws:
java.io.IOException
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected KerasModel()
Method Detail
protected void helperPrepareLayers(java.util.List<java.lang.Object> layerConfigs) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
layerConfigs - List of Keras layer configurations
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected void helperImportTrainingConfiguration(java.lang.String trainingConfigJson) throws java.io.IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
trainingConfigJson - JSON containing Keras training configuration
Throws:
java.io.IOException
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected void helperInferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
protected void helperImportWeights(Hdf5Archive weightsArchive, java.lang.String weightsRoot) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
weightsArchive - Hdf5Archive
weightsRoot
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public ComputationGraphConfiguration getComputationGraphConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public ComputationGraph getComputationGraph() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public ComputationGraph getComputationGraph(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
importWeights - whether to import weights
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public static java.util.Map<java.lang.String,java.lang.Object> parseJsonString(java.lang.String json) throws java.io.IOException
Parameters:
json - String containing valid JSON
Throws:
java.io.IOException
public static java.util.Map<java.lang.String,java.lang.Object> parseYamlString(java.lang.String json) throws java.io.IOException
Parameters:
json - String containing valid YAML
Throws:
java.io.IOException
protected Model helperCopyWeightsToModel(Model model) throws InvalidKerasConfigurationException
Parameters:
model - DL4J Model interface
Throws:
InvalidKerasConfigurationException