public class EvaluationBinary extends BaseEvaluation<EvaluationBinary>
ROCBinary is also used internally to calculate AUC for each output, but only when using an appropriate constructor, EvaluationBinary(int, Integer).
Note that EvaluationBinary supports both per-example and per-output masking.
The most common use case is multi-task networks, where each output is a binary value. This differs from Evaluation, which evaluates a single classification output (binary or non-binary).
Modifier and Type | Field and Description |
---|---|
static int | DEFAULT_PRECISION |
Constructor and Description |
---|
EvaluationBinary(int size) |
EvaluationBinary(int size, java.lang.Integer rocBinarySteps): This constructor allows for ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null. |
Modifier and Type | Method and Description |
---|---|
double | accuracy(int outputNum): Get the accuracy for the specified output |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions) |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray) |
void | evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions, org.nd4j.linalg.api.ndarray.INDArray labelsMask) |
double | f1(int outputNum): Get the F1 score for the specified output |
int | falseNegatives(int outputNum): Get the false negatives count for the specified output |
int | falsePositives(int outputNum): Get the false positives count for the specified output |
ROCBinary | getROCBinary(): Returns the ROCBinary instance, if present |
void | merge(EvaluationBinary other) |
int | numLabels(): Returns the number of labels (i.e., size of the prediction/labels arrays), if known. |
double | precision(int outputNum): Get the precision (tp / (tp + fp)) for the specified output |
double | recall(int outputNum): Get the recall (tp / (tp + fn)) for the specified output |
void | setLabelNames(java.util.List<java.lang.String> labels): Set the label names, for printing via stats() |
java.lang.String | stats(): Get a String representation of the EvaluationBinary class, using the default precision |
java.lang.String | stats(int printPrecision): Get a String representation of the EvaluationBinary class, using the specified precision |
int | totalCount(int outputNum): Get the total number of values for the specified column, accounting for any masking |
int | trueNegatives(int outputNum): Get the true negatives count for the specified output |
int | truePositives(int outputNum): Get the true positives count for the specified output |
Methods inherited from class BaseEvaluation: eval, evalTimeSeries
public static final int DEFAULT_PRECISION
public EvaluationBinary(int size)
public EvaluationBinary(int size, java.lang.Integer rocBinarySteps)
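This constructor allows for ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null. See ROCBinary for more details.
Parameters:
size - Number of outputs
rocBinarySteps - Constructor arg for ROCBinary.ROCBinary(int)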
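As a rough illustration of what a "steps" argument can mean for a thresholded ROC calculation, the sketch below computes AUC for a single output in plain Java. It assumes evenly spaced thresholds 0, 1/steps, ..., 1 and trapezoidal integration; whether ROCBinary uses exactly this scheme is an assumption, and `ThresholdRocSketch` is a hypothetical name, not part of the library.

```java
/** Hypothetical sketch of a thresholded ROC/AUC calculation; not the ROCBinary implementation. */
class ThresholdRocSketch {

    /** AUC for one output, using thresholds 0, 1/steps, ..., 1 and the trapezoidal rule. */
    static double auc(double[] labels, double[] probabilities, int steps) {
        int positives = 0;
        int negatives = 0;
        for (double label : labels) {
            if (label > 0.5) positives++; else negatives++;
        }

        // One ROC point per threshold, plus a closing point at (0, 0).
        // If either class is absent, the divisions below yield NaN.
        double[] tpr = new double[steps + 2];
        double[] fpr = new double[steps + 2];
        for (int s = 0; s <= steps; s++) {
            double threshold = s / (double) steps;
            int tp = 0;
            int fp = 0;
            for (int i = 0; i < labels.length; i++) {
                if (probabilities[i] >= threshold) {
                    if (labels[i] > 0.5) tp++; else fp++;
                }
            }
            tpr[s] = tp / (double) positives;
            fpr[s] = fp / (double) negatives;
        }
        // tpr[steps + 1] and fpr[steps + 1] stay 0.0, closing the curve at the origin.

        // As the threshold rises, FPR falls, so each strip has width fpr[s] - fpr[s + 1].
        double area = 0.0;
        for (int s = 0; s <= steps; s++) {
            area += (fpr[s] - fpr[s + 1]) * (tpr[s] + tpr[s + 1]) / 2.0;
        }
        return area;
    }
}
```

More steps give a finer approximation of the curve at the cost of more counting work per output.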
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions)
public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions, org.nd4j.linalg.api.ndarray.INDArray labelsMask)
Specified by: evalTimeSeries in interface IEvaluation<EvaluationBinary>
Overrides: evalTimeSeries in class BaseEvaluation<EvaluationBinary>
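One way to picture time series evaluation with a labels mask is as a flattening step: every unmasked time step becomes one ordinary row of per-output values, which can then be evaluated like non-sequence data. The plain-Java sketch below shows that idea only. The array layout ([examples][outputs][timeSteps] for data, [examples][timeSteps] for the mask) and the class name `TimeSeriesFlattenSketch` are assumptions made for illustration, not taken from this page.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of flattening masked time series into 2D rows. */
class TimeSeriesFlattenSketch {

    /**
     * series is [examples][outputs][timeSteps]; mask is [examples][timeSteps] or null.
     * Returns one row of per-output values for every unmasked time step.
     */
    static double[][] flatten(double[][][] series, double[][] mask) {
        List<double[]> rows = new ArrayList<>();
        for (int i = 0; i < series.length; i++) {
            int outputs = series[i].length;
            int timeSteps = outputs == 0 ? 0 : series[i][0].length;
            for (int t = 0; t < timeSteps; t++) {
                if (mask != null && mask[i][t] == 0.0) {
                    continue; // padded or otherwise masked step: excluded from evaluation
                }
                double[] row = new double[outputs];
                for (int j = 0; j < outputs; j++) {
                    row[j] = series[i][j][t];
                }
                rows.add(row);
            }
        }
        return rows.toArray(new double[0][]);
    }
}
```

Applying the same function to both labels and predictions keeps their rows aligned.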
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray)
Specified by: eval in interface IEvaluation<EvaluationBinary>
Overrides: eval in class BaseEvaluation<EvaluationBinary>
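To make the difference between per-example and per-output masking concrete, here is a plain-Java sketch that keeps per-output confusion counts without using ND4J. It assumes a 0.5 decision threshold, and that a mask with a single column applies to the whole example while a mask with one column per output applies element by element. `MaskedCountsSketch` and its fields are hypothetical names; this is not the library's implementation.

```java
/** Hypothetical stand-in for per-output binary counters with optional masking. */
class MaskedCountsSketch {
    final int[] tp;
    final int[] tn;
    final int[] fp;
    final int[] fn;

    MaskedCountsSketch(int numOutputs) {
        tp = new int[numOutputs];
        tn = new int[numOutputs];
        fp = new int[numOutputs];
        fn = new int[numOutputs];
    }

    /**
     * labels and predictions are [examples][outputs]. mask may be null (no masking),
     * [examples][1] (per-example) or [examples][outputs] (per-output); 0 means "skip".
     */
    void eval(double[][] labels, double[][] predictions, double[][] mask) {
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < tp.length; j++) {
                if (mask != null) {
                    double m = mask[i].length == 1 ? mask[i][0] : mask[i][j];
                    if (m == 0.0) {
                        continue; // masked entries are not counted at all
                    }
                }
                boolean actual = labels[i][j] > 0.5;
                boolean predicted = predictions[i][j] > 0.5; // assumed decision threshold
                if (actual && predicted) tp[j]++;
                else if (!actual && !predicted) tn[j]++;
                else if (predicted) fp[j]++;
                else fn[j]++;
            }
        }
    }

    /** Number of counted (unmasked) values for one output column. */
    int totalCount(int outputNum) {
        return tp[outputNum] + tn[outputNum] + fp[outputNum] + fn[outputNum];
    }
}
```

With per-output masking, different outputs can end up with different total counts, which is why the counts are tracked per column.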
public void merge(EvaluationBinary other)
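Merging is typically useful when evaluation is computed on separate partitions of the data and the results need to be combined. A minimal sketch of the idea, assuming that merging amounts to summing the per-output counts element by element (an assumption about the behaviour, with a hypothetical helper name `MergeSketch`):

```java
/** Hypothetical sketch: combine per-output counts from two evaluations by summation. */
class MergeSketch {

    /** Adds the counts in other to target, element by element. */
    static void mergeInto(int[] target, int[] other) {
        if (target.length != other.length) {
            throw new IllegalArgumentException(
                    "Cannot merge: " + target.length + " outputs vs " + other.length);
        }
        for (int i = 0; i < target.length; i++) {
            target[i] += other[i];
        }
    }
}
```

The same operation would be applied to each of the four count arrays (true/false positives and negatives), after which metrics computed from the merged counts reflect all partitions.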
public int numLabels()
public void setLabelNames(java.util.List<java.lang.String> labels)
Set the label names, for printing via stats()
public int totalCount(int outputNum)
public int truePositives(int outputNum)
public int trueNegatives(int outputNum)
public int falsePositives(int outputNum)
public int falseNegatives(int outputNum)
public double accuracy(int outputNum)
public double precision(int outputNum)
public double recall(int outputNum)
public double f1(int outputNum)
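The four metrics above all derive from the per-output confusion counts. The sketch below spells out the arithmetic in plain Java: precision and recall follow the formulas given in the method summary, while accuracy and F1 use their standard definitions. `BinaryMetricsSketch` is a hypothetical helper, and the sketch does not guard against empty denominators (it returns NaN in that case), which may differ from the library's behaviour.

```java
/** Hypothetical helper showing how the per-output metrics follow from the counts. */
class BinaryMetricsSketch {

    static double accuracy(int tp, int tn, int fp, int fn) {
        return (tp + tn) / (double) (tp + tn + fp + fn);
    }

    /** tp / (tp + fp) */
    static double precision(int tp, int fp) {
        return tp / (double) (tp + fp);
    }

    /** tp / (tp + fn) */
    static double recall(int tp, int fn) {
        return tp / (double) (tp + fn);
    }

    /** Harmonic mean of precision and recall. */
    static double f1(int tp, int fp, int fn) {
        double p = precision(tp, fp);
        double r = recall(tp, fn);
        return 2.0 * p * r / (p + r);
    }
}
```

For example, with 8 true positives, 2 false positives, 8 false negatives and 82 true negatives, precision is 0.8, recall is 0.5 and accuracy is 0.9.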
public java.lang.String stats()
public java.lang.String stats(int printPrecision)
Parameters:
printPrecision - The precision (number of decimal places) for the accuracy, f1, etc.
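As a small illustration of what "number of decimal places" means here, the following plain-Java sketch formats a metric value with a caller-chosen precision. It only demonstrates the formatting idea; the helper name `PrecisionFormatSketch` is hypothetical and the actual layout of the stats() output is not shown on this page.

```java
import java.util.Locale;

/** Hypothetical helper: format a metric with a given number of decimal places. */
class PrecisionFormatSketch {

    static String format(double value, int printPrecision) {
        // Locale.ROOT keeps '.' as the decimal separator regardless of system locale.
        return String.format(Locale.ROOT, "%." + printPrecision + "f", value);
    }
}
```

Calling `PrecisionFormatSketch.format(0.9, 4)` yields `0.9000`.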