public class ROCBinary extends BaseEvaluation<ROCBinary>
Some ROC implementations will automatically calculate the threshold points based on the data set to give a 'smoother' ROC curve (or optimal cut points for diagnostic purposes). This implementation currently uses fixed steps of size 1.0 / thresholdSteps, as this allows easy implementation for batched and distributed evaluation scenarios (where the full data set is not available in memory on any one machine at once).
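The fixed-step scheme described above can be sketched in plain Java. This is a minimal illustration, not the library's code: the class `FixedThresholdCounts`, its fields, the inclusion of both endpoints 0.0 and 1.0, and the use of `>=` for a positive prediction are all assumptions made for the example.

```java
/** Hypothetical helper (not part of the library): counts at fixed thresholds. */
final class FixedThresholdCounts {
    final int thresholdSteps;
    final long[] truePositives;   // index i corresponds to threshold i / thresholdSteps
    final long[] falsePositives;
    long actualPositives;
    long actualNegatives;

    FixedThresholdCounts(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
        // Assumption: both endpoints are kept, giving thresholdSteps + 1 points.
        this.truePositives = new long[thresholdSteps + 1];
        this.falsePositives = new long[thresholdSteps + 1];
    }

    /** Threshold value for step i, using a step size of 1.0 / thresholdSteps. */
    double threshold(int i) {
        return (double) i / thresholdSteps;
    }

    /** Accumulates one batch; labels are 0 or 1, probabilities lie in [0, 1]. */
    void addBatch(int[] labels, double[] probabilities) {
        for (int n = 0; n < labels.length; n++) {
            boolean positive = labels[n] == 1;
            if (positive) {
                actualPositives++;
            } else {
                actualNegatives++;
            }
            for (int i = 0; i <= thresholdSteps; i++) {
                // Assumption: a probability at or above the threshold predicts "positive".
                if (probabilities[n] >= threshold(i)) {
                    if (positive) {
                        truePositives[i]++;
                    } else {
                        falsePositives[i]++;
                    }
                }
            }
        }
    }
}
```

Because the thresholds do not depend on the data, feeding the examples in one batch or in several batches yields the same counts, which is the property that makes batched evaluation straightforward.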
Unlike ROC (which supports a single binary label, given either as a single-column probability or as a 2-column 'softmax' probability distribution), ROCBinary assumes that all outputs are independent binary variables. This also differs from ROCMultiClass, which should be used for multi-class (single non-binary) cases.
ROCBinary supports per-example and per-output masking: for per-output masking, any particular output may be absent (mask value 0) and hence won't be included in the calculated ROC.
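To make the per-output masking rule concrete, here is a small plain-Java sketch of counting actual positives and negatives for one output column while skipping masked entries. The class `MaskedCounts` and the use of `double[][]` arrays in place of INDArray are assumptions for illustration only; the library's internals are not shown on this page.

```java
/** Hypothetical helper (not part of the library): per-output masked counting. */
final class MaskedCounts {
    /**
     * Returns {actualPositive, actualNegative} for one output column.
     * Rows are examples, columns are outputs. A mask value of 0 marks the
     * output as absent for that example; a null mask means nothing is masked.
     */
    static long[] countActual(double[][] labels, double[][] mask, int outputNum) {
        long positives = 0;
        long negatives = 0;
        for (int row = 0; row < labels.length; row++) {
            if (mask != null && mask[row][outputNum] == 0.0) {
                continue; // absent output: excluded from the ROC counts
            }
            if (labels[row][outputNum] == 1.0) {
                positives++;
            } else {
                negatives++;
            }
        }
        return new long[] {positives, negatives};
    }
}
```

Each column is counted independently, so an example can contribute to one output's curve while being excluded from another's.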
Modifier and Type | Class and Description |
---|---|
static class | ROCBinary.CountsForThreshold |
static class | ROCBinary.PrecisionRecallPoint |
static class | ROCBinary.ROCValue |
Modifier and Type | Field and Description |
---|---|
static int | DEFAULT_PRECISION |
Constructor and Description |
---|
ROCBinary(int thresholdSteps) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(int outputNum): Calculate the AUC (Area Under Curve). Uses trapezoidal integration internally. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions) |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray) |
long | getCountActualNegative(int outputNum): Get the actual negative count (accounting for any masking) for the specified output/column. |
long | getCountActualPositive(int outputNum): Get the actual positive count (accounting for any masking) for the specified output/column. |
java.util.List<ROCBinary.PrecisionRecallPoint> | getPrecisionRecallCurve(int outputNum): Get the precision/recall curve for the specified output. |
java.util.List<ROCBinary.ROCValue> | getResults(int outputNum): Get the ROC curve, as a set of points. |
double[][] | getResultsAsArray(int outputNum): Get the ROC curve, as a set of (falsePositive, truePositive) points. |
void | merge(ROCBinary other) |
int | numLabels(): Returns the number of labels (i.e., the size of the prediction/labels arrays), if known. |
void | setLabelNames(java.util.List<java.lang.String> labels): Set the label names, for printing via stats(). |
java.lang.String | stats() |
java.lang.String | stats(int printPrecision) |
Methods inherited from class BaseEvaluation: eval, evalTimeSeries, evalTimeSeries
public static final int DEFAULT_PRECISION
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions)
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray)
eval in interface IEvaluation<ROCBinary>
eval in class BaseEvaluation<ROCBinary>
public void merge(ROCBinary other)
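With fixed thresholds, one plausible way to combine partial results from separate workers is to sum the per-threshold counts. The sketch below illustrates that idea only; `MergeableCounts` is a hypothetical class, and this page does not show how merge(ROCBinary) is actually implemented.

```java
/** Hypothetical helper (not part of the library): merging per-threshold counts. */
final class MergeableCounts {
    final long[] truePositives;
    final long[] falsePositives;

    MergeableCounts(long[] truePositives, long[] falsePositives) {
        // Defensive copies so that merging never mutates the caller's arrays.
        this.truePositives = truePositives.clone();
        this.falsePositives = falsePositives.clone();
    }

    /** Adds the other instance's counts into this one, threshold by threshold. */
    void merge(MergeableCounts other) {
        if (other.truePositives.length != truePositives.length
                || other.falsePositives.length != falsePositives.length) {
            throw new IllegalArgumentException("Threshold counts differ in length");
        }
        for (int i = 0; i < truePositives.length; i++) {
            truePositives[i] += other.truePositives[i];
            falsePositives[i] += other.falsePositives[i];
        }
    }
}
```

Summation only works when both sides use the same threshold grid, which is why the sketch rejects inputs of different lengths.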
public int numLabels()
public long getCountActualPositive(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public long getCountActualNegative(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public java.util.List<ROCBinary.ROCValue> getResults(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public java.util.List<ROCBinary.PrecisionRecallPoint> getPrecisionRecallCurve(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public double[][] getResultsAsArray(int outputNum)
Returns a 2d array of {falsePositive, truePositive} values.
Size is [2][thresholdSteps], with out[0][.] being false positives, and out[1][.] being true positives.
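An array in the [2][n] layout described above can be integrated with the trapezoidal rule, the method the summary names for calculateAUC. The following is a minimal plain-Java sketch, assuming the values are rates in [0, 1] and that consecutive points are ordered monotonically by false positive rate; `TrapezoidAuc` is a hypothetical helper, not the library's implementation.

```java
/** Hypothetical helper (not part of the library): trapezoidal AUC. */
final class TrapezoidAuc {
    /**
     * curve[0][i] is the false positive rate and curve[1][i] the true
     * positive rate at threshold step i.
     */
    static double auc(double[][] curve) {
        double[] fpr = curve[0];
        double[] tpr = curve[1];
        double area = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            // Width of the strip times the mean height of its two sides.
            // Math.abs lets the points run in either direction along the x axis.
            double width = Math.abs(fpr[i] - fpr[i - 1]);
            area += width * (tpr[i] + tpr[i - 1]) / 2.0;
        }
        return area;
    }
}
```

For example, a curve lying on the diagonal from (0, 0) to (1, 1) integrates to 0.5, the value expected of a classifier with no discriminative power.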
public double calculateAUC(int outputNum)
outputNum -
public void setLabelNames(java.util.List<java.lang.String> labels)
Set the label names, for printing via stats()
public java.lang.String stats()
public java.lang.String stats(int printPrecision)