`public abstract class BaseSparkEarlyStoppingTrainer<T extends Model> extends java.lang.Object implements IEarlyStoppingTrainer<T>`

Base class for conducting early stopping training via Spark on either a `MultiLayerNetwork` or a `ComputationGraph`.
Modifier | Constructor and Description |
---|---|
`protected` | `BaseSparkEarlyStoppingTrainer(org.apache.spark.api.java.JavaSparkContext sc, EarlyStoppingConfiguration<T> esConfig, T net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> train, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> trainMulti, EarlyStoppingListener<T> listener)` |
Modifier and Type | Method and Description |
---|---|
`EarlyStoppingResult<T>` | `fit()` — Conduct early stopping training. |
`protected abstract void` | `fit(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> data)` |
`protected abstract void` | `fitMulti(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> data)` |
`protected abstract double` | `getScore()` |
`void` | `setListener(EarlyStoppingListener<T> listener)` — Set the early stopping listener. |
`protected BaseSparkEarlyStoppingTrainer(org.apache.spark.api.java.JavaSparkContext sc, EarlyStoppingConfiguration<T> esConfig, T net, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> train, org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> trainMulti, EarlyStoppingListener<T> listener)`

`protected abstract void fit(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.DataSet> data)`

`protected abstract void fitMulti(org.apache.spark.api.java.JavaRDD<org.nd4j.linalg.dataset.api.MultiDataSet> data)`

`protected abstract double getScore()`
`public EarlyStoppingResult<T> fit()`

Description copied from interface: `IEarlyStoppingTrainer`

Conduct early stopping training.

Specified by: `fit` in interface `IEarlyStoppingTrainer<T extends Model>`
`public void setListener(EarlyStoppingListener<T> listener)`

Description copied from interface: `IEarlyStoppingTrainer`

Set the early stopping listener.

Specified by: `setListener` in interface `IEarlyStoppingTrainer<T extends Model>`