public class ROCMultiClass extends BaseEvaluation<ROCMultiClass>
The ROC curves are produced by treating the predictions as a set of one-vs-all classifiers and calculating a ROC curve for each. In practice, this means that N classes yield N ROC curves, one per class.
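A minimal sketch of the one-vs-all reduction in plain Java. It assumes one-hot labels and per-class probabilities stored as `double[][]` arrays of shape [examples][classes]. The class `OneVsAllSketch` and its methods are hypothetical illustrations, not part of the library, whose real class works on `INDArray`s:

```java
/**
 * Hypothetical sketch of the one-vs-all reduction (not the library's code).
 * Assumes labels are one-hot and predictions are per-class probabilities,
 * both shaped [numExamples][numClasses].
 */
class OneVsAllSketch {

    /** Binary labels for class c: true where the example's true class is c. */
    static boolean[] binaryLabels(double[][] oneHotLabels, int c) {
        boolean[] out = new boolean[oneHotLabels.length];
        for (int i = 0; i < oneHotLabels.length; i++) {
            out[i] = oneHotLabels[i][c] > 0.5; // one-hot entries are 0.0 or 1.0
        }
        return out;
    }

    /** Scores for class c: the probability the model assigned to class c. */
    static double[] scores(double[][] predictions, int c) {
        double[] out = new double[predictions.length];
        for (int i = 0; i < predictions.length; i++) {
            out[i] = predictions[i][c];
        }
        return out;
    }
}
```

Calling both methods once per class index c = 0..N-1 gives N independent binary problems, and each one feeds its own ROC curve.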
Some ROC implementations will automatically calculate the threshold points based on the data set to give a 'smoother' ROC curve (or optimal cut points for diagnostic purposes). This implementation currently uses fixed steps of size 1.0 / thresholdSteps, as this allows easy implementation for batched and distributed evaluation scenarios (where the full data set is not available in memory on any one machine at once).
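The fixed-step scheme can be sketched for a single one-vs-all problem as below. This is a hypothetical illustration, not the library's implementation. It assumes thresholds `i / thresholdSteps` for `i = 0..thresholdSteps` (so `thresholdSteps + 1` points) and counts `score >= threshold` as a positive prediction; the library's exact threshold count and comparison may differ. The class name `FixedThresholdCounts` is invented for this example.

```java
/**
 * Hypothetical sketch (not the library's code): fixed-threshold ROC counts for
 * ONE binary (one-vs-all) problem. Thresholds are i / thresholdSteps for
 * i = 0..thresholdSteps, and "score >= threshold" counts as a positive
 * prediction; both choices are assumptions.
 */
class FixedThresholdCounts {
    final int thresholdSteps;
    final long[] truePositives;  // one count per threshold
    final long[] falsePositives; // one count per threshold
    long positives;              // actual positives seen so far
    long negatives;              // actual negatives seen so far

    FixedThresholdCounts(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
        this.truePositives = new long[thresholdSteps + 1];
        this.falsePositives = new long[thresholdSteps + 1];
    }

    /** Threshold i, i.e. i steps of size 1.0 / thresholdSteps. */
    double threshold(int i) {
        return (double) i / thresholdSteps;
    }

    /** Accumulate one minibatch; only counts are kept, never the raw scores. */
    void eval(boolean[] actual, double[] score) {
        for (int j = 0; j < actual.length; j++) {
            if (actual[j]) {
                positives++;
            } else {
                negatives++;
            }
            for (int i = 0; i <= thresholdSteps; i++) {
                if (score[j] >= threshold(i)) {
                    if (actual[j]) {
                        truePositives[i]++;
                    } else {
                        falsePositives[i]++;
                    }
                }
            }
        }
    }

    public static void main(String[] args) {
        FixedThresholdCounts counts = new FixedThresholdCounts(2);
        counts.eval(new boolean[] {true, false}, new double[] {0.9, 0.6}); // minibatch 1
        counts.eval(new boolean[] {true, false}, new double[] {0.4, 0.1}); // minibatch 2
        for (int i = 0; i <= counts.thresholdSteps; i++) {
            System.out.println("threshold " + counts.threshold(i)
                    + ": TP=" + counts.truePositives[i] + ", FP=" + counts.falsePositives[i]);
        }
    }
}
```

The threshold grid does not depend on the data, so minibatches can arrive in any order and each adds to the same fixed-size counters. The trade-off is a curve that is only as fine as `thresholdSteps` allows.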
Constructor and Description |
---|
ROCMultiClass(int thresholdSteps) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(int classIdx) - Calculate the AUC (Area Under Curve), using trapezoidal integration internally |
double | calculateAverageAUC() - Calculate the average (one-vs-all) AUC for all classes |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions) - Evaluate (collect statistics for) the given minibatch of data. |
java.util.List<ROC.PrecisionRecallPoint> | getPrecisionRecallCurve(int classIndex) |
java.util.List<ROC.ROCValue> | getResults(int classIdx) - Get the ROC curve, as a set of points |
double[][] | getResultsAsArray(int classIdx) - Get the ROC curve, as a set of (falsePositive, truePositive) points |
void | merge(ROCMultiClass other) - Merge this ROCMultiClass instance with another. |
Methods inherited from class BaseEvaluation: eval, eval, evalTimeSeries, evalTimeSeries
public ROCMultiClass(int thresholdSteps)
thresholdSteps - Number of threshold steps to use for the ROC calculation
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions)
Evaluate (collect statistics for) the given minibatch of data. For time series (3d) data, use BaseEvaluation.evalTimeSeries(INDArray, INDArray) or BaseEvaluation.evalTimeSeries(INDArray, INDArray, INDArray) instead.
labels - Labels / true outcomes
predictions - Predictions
public java.util.List<ROC.ROCValue> getResults(int classIdx)
classIdx - Index of the class to get the (one-vs-all) ROC curve for
public double[][] getResultsAsArray(int classIdx)
Returns a 2d array of {falsePositive, truePositive} values. The size is [2][thresholdSteps], with out[0][.] holding the false positives and out[1][.] the true positives.
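A hedged sketch of how such a [2][n] array could be filled from per-threshold counts. It assumes the stored values are rates (false positive rate and true positive rate, since a ROC curve plots one against the other), which the description above does not state explicitly. It also assumes 0.0 is returned when a class has no positives or no negatives. `RocArraySketch` and `toRocArray` are hypothetical names.

```java
/**
 * Hypothetical sketch: turn per-threshold counts into a [2][n] array, with
 * out[0][i] = false positive rate and out[1][i] = true positive rate at
 * threshold i. Returning 0.0 when there are no positives/negatives is an
 * assumption.
 */
class RocArraySketch {
    static double[][] toRocArray(long[] truePositives, long[] falsePositives,
                                 long positives, long negatives) {
        int n = truePositives.length;
        double[][] out = new double[2][n];
        for (int i = 0; i < n; i++) {
            // row 0: FP / actual negatives, row 1: TP / actual positives
            out[0][i] = negatives == 0 ? 0.0 : (double) falsePositives[i] / negatives;
            out[1][i] = positives == 0 ? 0.0 : (double) truePositives[i] / positives;
        }
        return out;
    }
}
```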
public double calculateAUC(int classIdx)
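The trapezoidal integration mentioned in the method summary can be sketched as follows. This is a hypothetical standalone version over ROC points `(fpr[i], tpr[i])`, not the library's code. It assumes the points are ordered monotonically in false positive rate (in either direction), as they are when fixed thresholds are visited in order. If the point set does not reach (0, 0) and (1, 1), the sum covers only the span between its extreme points.

```java
/**
 * Hypothetical sketch of trapezoidal AUC over ROC points (fpr[i], tpr[i]).
 * Assumes the points are ordered monotonically in fpr (either direction).
 */
class TrapezoidAucSketch {
    static double auc(double[] fpr, double[] tpr) {
        double area = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            double width = Math.abs(fpr[i] - fpr[i - 1]);     // horizontal step
            double meanHeight = (tpr[i] + tpr[i - 1]) / 2.0;   // average of the two sides
            area += width * meanHeight;                        // area of one trapezoid
        }
        return area;
    }
}
```

A random classifier, whose points lie on the diagonal, gives 0.5; a perfect classifier gives 1.0.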
public double calculateAverageAUC()
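One plausible reading of "average (one-vs-all) AUC for all classes" is an unweighted (macro) mean of the per-class AUC values. The sketch below assumes that reading; whether the library weights every class equally, regardless of how many examples each has, is an assumption here. `AverageAucSketch` is a hypothetical name.

```java
/**
 * Hypothetical sketch: average one-vs-all AUC as the unweighted (macro)
 * mean of per-class AUC values. Equal class weighting is an assumption.
 */
class AverageAucSketch {
    static double averageAuc(double[] perClassAuc) {
        if (perClassAuc.length == 0) {
            throw new IllegalArgumentException("need at least one class");
        }
        double sum = 0.0;
        for (double auc : perClassAuc) {
            sum += auc;
        }
        return sum / perClassAuc.length;
    }
}
```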
public java.util.List<ROC.PrecisionRecallPoint> getPrecisionRecallCurve(int classIndex)
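A precision-recall curve can be derived from the same per-threshold counts as the ROC curve. The sketch below uses the standard definitions, precision = TP / (TP + FP) and recall = TP / (actual positives). It is a hypothetical illustration, not the library's implementation. Returning precision 1.0 when nothing is predicted positive is an assumed convention; libraries differ on that edge case.

```java
/**
 * Hypothetical sketch: one precision-recall point from per-threshold counts.
 * precision = TP / (TP + FP), recall = TP / (all actual positives).
 * Precision 1.0 when nothing is predicted positive is an assumed convention.
 */
class PrecisionRecallSketch {
    /** Returns {precision, recall} for a single threshold. */
    static double[] point(long truePositives, long falsePositives, long positives) {
        long predictedPositive = truePositives + falsePositives;
        double precision = predictedPositive == 0 ? 1.0 : (double) truePositives / predictedPositive;
        double recall = positives == 0 ? 0.0 : (double) truePositives / positives;
        return new double[] {precision, recall};
    }
}
```

Evaluating `point` once per threshold, for the class selected by `classIndex`, yields the curve.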
public void merge(ROCMultiClass other)
other - ROCMultiClass instance to combine with this one
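Fixed threshold steps are what make merging straightforward: counts collected on different minibatches or machines share the same threshold grid, so they can be summed element-wise. The sketch below shows this for one binary problem. For a multi-class evaluation, one such set of counts per class would be merged the same way. `MergeableRocCounts` is a hypothetical class, and rejecting instances with different step counts is an assumption about sensible behavior, not documented library behavior.

```java
/**
 * Hypothetical sketch: per-threshold counts from separate minibatches or
 * machines are merged by element-wise summation. Both instances must use
 * the same number of threshold steps.
 */
class MergeableRocCounts {
    final long[] truePositives;
    final long[] falsePositives;
    long positives;
    long negatives;

    MergeableRocCounts(int thresholdSteps) {
        truePositives = new long[thresholdSteps + 1];
        falsePositives = new long[thresholdSteps + 1];
    }

    /** Adds the other instance's counts into this one (this instance is modified). */
    void merge(MergeableRocCounts other) {
        if (other.truePositives.length != truePositives.length) {
            throw new IllegalArgumentException("threshold steps differ");
        }
        for (int i = 0; i < truePositives.length; i++) {
            truePositives[i] += other.truePositives[i];
            falsePositives[i] += other.falsePositives[i];
        }
        positives += other.positives;
        negatives += other.negatives;
    }
}
```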