public class EarlyStoppingParallelTrainer<T extends Model> extends java.lang.Object implements IEarlyStoppingTrainer<T>
Trains a MultiLayerNetwork or a ComputationGraph in parallel (across multiple workers)
via early stopping.

Modifier and Type | Field and Description |
---|---|
protected EarlyStoppingConfiguration<T> | esConfig |
protected T | model |
protected IterationTerminationCondition | terminationReason |
Constructor and Description |
---|
EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, org.nd4j.linalg.dataset.api.iterator.DataSetIterator train, org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener, int workers, int prefetchBuffer, int averagingFrequency) |
EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, org.nd4j.linalg.dataset.api.iterator.DataSetIterator train, org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener, int workers, int prefetchBuffer, int averagingFrequency, boolean reportScoreAfterAveraging, boolean useLegacyAveraging) |
EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, org.nd4j.linalg.dataset.api.iterator.DataSetIterator train, org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator trainMulti, int workers, int prefetchBuffer, int averagingFrequency) |
Modifier and Type | Method and Description |
---|---|
EarlyStoppingResult<T> | fit(): Conduct early stopping training |
boolean | getTermination() |
void | incrementIteration() |
protected void | reset() |
void | setLatestScore(double latestScore) |
void | setListener(EarlyStoppingListener<T> listener): Set the early stopping listener |
void | setTermination(boolean terminate) |
protected void | setTerminationReason(IterationTerminationCondition terminationReason) |
protected final EarlyStoppingConfiguration<T extends Model> esConfig
protected volatile IterationTerminationCondition terminationReason
public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, org.nd4j.linalg.dataset.api.iterator.DataSetIterator train, org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator trainMulti, int workers, int prefetchBuffer, int averagingFrequency)
public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, org.nd4j.linalg.dataset.api.iterator.DataSetIterator train, org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener, int workers, int prefetchBuffer, int averagingFrequency)
public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration<T> earlyStoppingConfiguration, T model, org.nd4j.linalg.dataset.api.iterator.DataSetIterator train, org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator trainMulti, EarlyStoppingListener<T> listener, int workers, int prefetchBuffer, int averagingFrequency, boolean reportScoreAfterAveraging, boolean useLegacyAveraging)
protected void setTerminationReason(IterationTerminationCondition terminationReason)
public EarlyStoppingResult<T> fit()
Description copied from interface: IEarlyStoppingTrainer
Conduct early stopping training
Specified by:
fit in interface IEarlyStoppingTrainer<T extends Model>
public void setLatestScore(double latestScore)
public void incrementIteration()
public void setTermination(boolean terminate)
public boolean getTermination()
public void setListener(EarlyStoppingListener<T> listener)
Description copied from interface: IEarlyStoppingTrainer
Set the early stopping listener
Specified by:
setListener in interface IEarlyStoppingTrainer<T extends Model>
protected void reset()