public class ROC extends BaseEvaluation<ROC>
Some ROC implementations automatically calculate the threshold points from the data set to give a 'smoother' ROC curve (or optimal cut points for diagnostic purposes). This implementation currently uses fixed threshold steps of size 1.0 / thresholdSteps, because fixed thresholds make batched and distributed evaluation easy to implement (in those scenarios the full data set is not held in memory on any single machine).
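To illustrate why fixed thresholds suit batched evaluation, here is a minimal plain-Java sketch, not the library's code. The class FixedThresholdCounts and its fields are hypothetical. It assumes a grid of thresholds i * (1.0 / thresholdSteps) for i = 0..thresholdSteps, and it treats a prediction as positive when its probability is >= the threshold. This page confirms neither detail.

```java
/** Hypothetical sketch (not the DL4J implementation): per-threshold counts on a fixed grid. */
class FixedThresholdCounts {
    final int thresholdSteps;
    final long[] truePositives;   // truePositives[i]: positives predicted positive at threshold i
    final long[] falsePositives;  // falsePositives[i]: negatives predicted positive at threshold i
    long countActualPositive;
    long countActualNegative;

    FixedThresholdCounts(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
        // Assumed grid: i * (1.0 / thresholdSteps) for i = 0..thresholdSteps, both ends included
        this.truePositives = new long[thresholdSteps + 1];
        this.falsePositives = new long[thresholdSteps + 1];
    }

    double threshold(int i) {
        return i * (1.0 / thresholdSteps);
    }

    /** Accumulates one minibatch: probPositive[k] is the predicted positive probability, isPositive[k] the label. */
    void add(double[] probPositive, boolean[] isPositive) {
        if (probPositive.length != isPositive.length) {
            throw new IllegalArgumentException("Predictions and labels differ in length");
        }
        for (int k = 0; k < probPositive.length; k++) {
            if (isPositive[k]) countActualPositive++; else countActualNegative++;
            for (int i = 0; i <= thresholdSteps; i++) {
                // Assumption: predicted positive when probability >= threshold
                if (probPositive[k] >= threshold(i)) {
                    if (isPositive[k]) truePositives[i]++; else falsePositives[i]++;
                }
            }
        }
    }
}
```

Because the grid is fixed before any data arrives, each minibatch only increments counters. Feeding the data in two batches or in one produces identical counts, so batches (or machines) never need the full data set at once.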
The data is assumed to be for binary classification: either nColumns == 1 (a single binary output variable) or nColumns == 2 (a probability distribution over 2 classes, with column 1 holding the values for 'positive' examples).
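A minimal sketch of how the two supported layouts can be reduced to one positive-class value per example. BinaryColumns and its methods are hypothetical names, and plain double[][] arrays stand in for INDArray so that the example runs without ND4J.

```java
/** Hypothetical helper (not part of DL4J): selects the positive-class column for binary data. */
class BinaryColumns {
    /** nColumns == 1: the single column; nColumns == 2: column 1 holds the positive class. */
    static int positiveColumn(int nColumns) {
        if (nColumns == 1) return 0;
        if (nColumns == 2) return 1;
        throw new IllegalArgumentException("Binary ROC expects 1 or 2 columns, got " + nColumns);
    }

    /** Extracts the positive-class value from each row of a [nExamples][nColumns] matrix. */
    static double[] positiveValues(double[][] matrix) {
        double[] out = new double[matrix.length];
        for (int r = 0; r < matrix.length; r++) {
            out[r] = matrix[r][positiveColumn(matrix[r].length)];
        }
        return out;
    }
}
```

The same selection applies to labels and predictions, so a 2-column one-hot label row {0, 1} and a 1-column label row {1} both mark a positive example.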
Modifier and Type | Class and Description |
---|---|
static class | ROC.CountsForThreshold |
static class | ROC.PrecisionRecallPoint |
static class | ROC.ROCValue |
Constructor and Description |
---|
ROC(int thresholdSteps) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(): Calculate the AUC (Area Under Curve), using trapezoidal integration internally. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions): Evaluate (collect statistics for) the given minibatch of data. |
java.util.List<ROC.PrecisionRecallPoint> | getPrecisionRecallCurve() |
java.util.List<ROC.ROCValue> | getResults(): Get the ROC curve, as a set of points. |
double[][] | getResultsAsArray(): Get the ROC curve, as a set of (falsePositive, truePositive) points. |
void | merge(ROC other): Merge this ROC instance with another. |
Methods inherited from class BaseEvaluation: eval, eval, evalTimeSeries, evalTimeSeries
public ROC(int thresholdSteps)
thresholdSteps
- Number of threshold steps to use for the ROC calculation
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions)
Evaluate (collect statistics for) the given minibatch of data. For time series data, use BaseEvaluation.evalTimeSeries(INDArray, INDArray) or BaseEvaluation.evalTimeSeries(INDArray, INDArray, INDArray) instead.
labels
- Labels / true outcomes
predictions
- Predictions
public java.util.List<ROC.ROCValue> getResults()
Get the ROC curve, as a set of points.
public java.util.List<ROC.PrecisionRecallPoint> getPrecisionRecallCurve()
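This page gives no description for getPrecisionRecallCurve(). As a hedged illustration only, the sketch below computes the standard definitions of precision (TP / (TP + FP)) and recall (TP / actual positives) from hypothetical per-threshold counts. The class name PrecisionRecall, the [2][n] layout, and the choice of precision = 1.0 when nothing is predicted positive are all assumptions, not documented behavior of ROC.PrecisionRecallPoint.

```java
/** Hypothetical sketch (not the DL4J implementation): precision/recall per threshold. */
class PrecisionRecall {
    /** Returns out[0][i] = precision and out[1][i] = recall at threshold index i. */
    static double[][] curve(long[] truePositives, long[] falsePositives, long countActualPositive) {
        int n = truePositives.length;
        double[][] out = new double[2][n];
        for (int i = 0; i < n; i++) {
            long predictedPositive = truePositives[i] + falsePositives[i];
            // Assumption: precision is defined as 1.0 when nothing is predicted positive
            out[0][i] = predictedPositive == 0 ? 1.0 : (double) truePositives[i] / predictedPositive;
            // Recall = TP / (TP + FN) = TP / number of actual positives
            out[1][i] = countActualPositive == 0 ? 0.0 : (double) truePositives[i] / countActualPositive;
        }
        return out;
    }
}
```

As the threshold rises, recall can only stay the same or fall, while precision typically (but not necessarily) rises.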
public double[][] getResultsAsArray()
Returns a 2d array of {falsePositive, truePositive} values.
The size is [2][thresholdSteps], with out[0][.] holding the false positive values and out[1][.] holding the true positive values.
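A sketch of how such a {falsePositive, truePositive} array could be built from per-threshold counts. It assumes the values are rates (false positives / actual negatives, true positives / actual positives), as is conventional for ROC curves. RocArray and its parameters are hypothetical, and the array length here simply follows the number of counts passed in.

```java
/** Hypothetical sketch (not the DL4J implementation): ROC points as a [2][n] array. */
class RocArray {
    /** out[0][i] = false positive rate, out[1][i] = true positive rate at threshold index i. */
    static double[][] fromCounts(long[] truePositives, long[] falsePositives,
                                 long countActualPositive, long countActualNegative) {
        int n = truePositives.length;
        double[][] out = new double[2][n];
        for (int i = 0; i < n; i++) {
            // Guard against division by zero when a class is absent from the data
            out[0][i] = countActualNegative == 0 ? 0.0 : (double) falsePositives[i] / countActualNegative;
            out[1][i] = countActualPositive == 0 ? 0.0 : (double) truePositives[i] / countActualPositive;
        }
        return out;
    }
}
```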
public double calculateAUC()
Calculate the AUC (Area Under Curve), using trapezoidal integration internally.
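The method summary states that the AUC is computed with trapezoidal integration. Below is a minimal sketch of that rule. It assumes the curve is supplied as parallel false positive rate and true positive rate arrays, and TrapezoidAuc is a hypothetical class. The sketch sorts the points by false positive rate (ties broken by true positive rate) and sums the area of each trapezoid between neighbouring points.

```java
import java.util.Arrays;

/** Hypothetical sketch (not the DL4J implementation): trapezoidal area under an ROC curve. */
class TrapezoidAuc {
    static double auc(double[] fpr, double[] tpr) {
        if (fpr.length != tpr.length) throw new IllegalArgumentException("Length mismatch");
        // Sort point indices by (fpr, tpr) so the integration runs along the curve, left to right
        Integer[] idx = new Integer[fpr.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> {
            int c = Double.compare(fpr[a], fpr[b]);
            return c != 0 ? c : Double.compare(tpr[a], tpr[b]);
        });
        double area = 0.0;
        for (int k = 1; k < idx.length; k++) {
            int prev = idx[k - 1];
            int cur = idx[k];
            // Trapezoid: width along the fpr axis times the mean of the two tpr heights
            area += (fpr[cur] - fpr[prev]) * (tpr[cur] + tpr[prev]) / 2.0;
        }
        return area;
    }
}
```

For a perfectly separated data set, the points run from (0, 0) up to (0, 1) and across to (1, 1), giving an area of 1.0. A straight diagonal from (0, 0) to (1, 1) gives 0.5.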
public void merge(ROC other)
Merge this ROC instance with another.
other
- ROC instance to combine with this one
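Fixed threshold steps are what make merging cheap. If both instances hold counts on the same threshold grid, merging can reduce to element-wise addition. The sketch below assumes exactly that. MergeableCounts is a hypothetical class, and this page does not specify how ROC.merge is implemented internally (for example, whether it checks that thresholdSteps match).

```java
/** Hypothetical sketch (not the DL4J implementation): merging per-threshold counts. */
class MergeableCounts {
    final long[] truePositives;
    final long[] falsePositives;
    long countActualPositive;
    long countActualNegative;

    MergeableCounts(int nThresholds) {
        truePositives = new long[nThresholds];
        falsePositives = new long[nThresholds];
    }

    /** Adds other's counts into this instance; both must use the same threshold grid. */
    void merge(MergeableCounts other) {
        if (other.truePositives.length != truePositives.length) {
            throw new IllegalArgumentException("Cannot merge: different number of thresholds");
        }
        for (int i = 0; i < truePositives.length; i++) {
            truePositives[i] += other.truePositives[i];
            falsePositives[i] += other.falsePositives[i];
        }
        countActualPositive += other.countActualPositive;
        countActualNegative += other.countActualNegative;
    }
}
```

Because addition is associative and commutative, partial results from many workers can be merged in any order and still yield the same totals.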