public class Evaluation extends BaseEvaluation<Evaluation>
Modifier and Type | Field and Description |
---|---|
protected ConfusionMatrix<java.lang.Integer> | confusion |
protected java.util.Map<Pair<java.lang.Integer,java.lang.Integer>,java.util.List<java.lang.Object>> | confusionMatrixMetaData |
protected static double | DEFAULT_EDGE_VALUE |
protected Counter<java.lang.Integer> | falseNegatives |
protected Counter<java.lang.Integer> | falsePositives |
protected java.util.List<java.lang.String> | labelsList |
protected int | numRowCounter |
protected int | topN |
protected int | topNCorrectCount |
protected int | topNTotalCount |
protected Counter<java.lang.Integer> | trueNegatives |
protected Counter<java.lang.Integer> | truePositives |
Constructor and Description |
---|
Evaluation() |
Evaluation(int numClasses): The number of classes to account for in the evaluation |
Evaluation(java.util.List<java.lang.String> labels): The labels to include with the evaluation |
Evaluation(java.util.List<java.lang.String> labels, int topN): Constructor to use for top N accuracy |
Evaluation(java.util.Map<java.lang.Integer,java.lang.String> labels): Uses a map to generate labels. Pass in each label index with the actual label you want to use for output |
Modifier and Type | Method and Description |
---|---|
double | accuracy(): Accuracy, computed as (TP + TN) / (P + N) |
void | addToConfusion(java.lang.Integer real, java.lang.Integer guess): Adds an entry for the given actual and guessed class to the confusion matrix |
int | classCount(java.lang.Integer clazz): Returns the number of times the given label has actually occurred |
java.lang.String | confusionToString(): Gets a String representation of the confusion matrix |
void | eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses): Collects statistics on the real outcomes vs. the guesses |
void | eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, ComputationGraph network): Evaluates the output of the given ComputationGraph on the given input against the given true labels |
void | eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses, java.util.List<? extends java.io.Serializable> recordMetaData): Evaluates the guesses against the real outcomes, with optional per-record metadata |
void | eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, MultiLayerNetwork network): Evaluates the output of the given MultiLayerNetwork on the given input against the given true labels |
void | eval(int predictedIdx, int actualIdx): Evaluates a single prediction (one prediction at a time) |
double | f1(): F1 score: 2 * TP / (2 * TP + FP + FN), where TP, FP and FN are the true positive, false positive and false negative counts |
double | f1(java.lang.Integer classLabel): Calculates the F1 score for a given class |
double | falseAlarmRate(): False Alarm Rate (FAR), reflecting the rate of misclassified to classified records; see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw |
double | falseNegativeRate(): False negative rate based on guesses so far. Takes into account all known classes and outputs the average FNR across all of them |
double | falseNegativeRate(java.lang.Integer classLabel): Returns the false negative rate for a given label |
double | falseNegativeRate(java.lang.Integer classLabel, double edgeCase): Returns the false negative rate for a given label, or edgeCase in case of 0/0 |
java.util.Map<java.lang.Integer,java.lang.Integer> | falseNegatives(): False negatives: incorrectly rejected |
double | falsePositiveRate(): False positive rate based on guesses so far. Takes into account all known classes and outputs the average FPR across all of them |
double | falsePositiveRate(java.lang.Integer classLabel): Returns the false positive rate for a given label |
double | falsePositiveRate(java.lang.Integer classLabel, double edgeCase): Returns the false positive rate for a given label, or edgeCase in case of 0/0 |
java.util.Map<java.lang.Integer,java.lang.Integer> | falsePositives(): False positives: incorrectly accepted (wrong guesses) |
java.lang.String | getClassLabel(java.lang.Integer clazz) |
ConfusionMatrix<java.lang.Integer> | getConfusionMatrix(): Returns the confusion matrix |
int | getNumRowCounter() |
java.util.List<Prediction> | getPredictionByPredictedClass(int predictedClass): Gets a list of predictions for all data with the specified predicted class, regardless of the actual class |
java.util.List<Prediction> | getPredictionErrors(): Gets a list of prediction errors, on a per-record basis |
java.util.List<Prediction> | getPredictions(int actualClass, int predictedClass): Gets a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair) |
java.util.List<Prediction> | getPredictionsByActualClass(int actualClass): Gets a list of predictions for all data with the specified actual class, regardless of the predicted class |
int | getTopNCorrectCount(): Returns the number of correct predictions according to the top N value |
int | getTopNTotalCount(): Returns the total number of top N evaluations |
void | incrementFalseNegatives(java.lang.Integer classLabel) |
void | incrementFalsePositives(java.lang.Integer classLabel) |
void | incrementTrueNegatives(java.lang.Integer classLabel) |
void | incrementTruePositives(java.lang.Integer classLabel) |
void | merge(Evaluation other): Merges the other evaluation object into this one |
java.util.Map<java.lang.Integer,java.lang.Integer> | negative(): Total negatives: true negatives + false negatives |
java.util.Map<java.lang.Integer,java.lang.Integer> | positive(): Total positives: true positives + false negatives |
double | precision(): Precision based on guesses so far. Takes into account all known classes and outputs the average precision across all of them |
double | precision(java.lang.Integer classLabel): Returns the precision for a given label |
double | precision(java.lang.Integer classLabel, double edgeCase): Returns the precision for a given label, or edgeCase in case of 0/0 |
double | recall(): Recall based on guesses so far. Takes into account all known classes and outputs the average recall across all of them |
double | recall(java.lang.Integer classLabel): Returns the recall for a given label |
double | recall(java.lang.Integer classLabel, double edgeCase): Returns the recall for a given label, or edgeCase in case of 0/0 |
java.lang.String | stats() |
java.lang.String | stats(boolean suppressWarnings): Returns the classification report as a String |
double | topNAccuracy(): Top N accuracy of the predictions so far |
java.util.Map<java.lang.Integer,java.lang.Integer> | trueNegatives(): True negatives: correctly rejected |
java.util.Map<java.lang.Integer,java.lang.Integer> | truePositives(): True positives: correctly accepted |
Methods inherited from class BaseEvaluation: eval, evalTimeSeries, evalTimeSeries
protected final int topN
protected int topNCorrectCount
protected int topNTotalCount
protected Counter<java.lang.Integer> truePositives
protected Counter<java.lang.Integer> falsePositives
protected Counter<java.lang.Integer> trueNegatives
protected Counter<java.lang.Integer> falseNegatives
protected ConfusionMatrix<java.lang.Integer> confusion
protected int numRowCounter
protected java.util.List<java.lang.String> labelsList
protected static final double DEFAULT_EDGE_VALUE
protected java.util.Map<Pair<java.lang.Integer,java.lang.Integer>,java.util.List<java.lang.Object>> confusionMatrixMetaData
public Evaluation()
public Evaluation(int numClasses)
numClasses - the number of classes to account for in the evaluation
public Evaluation(java.util.List<java.lang.String> labels)
labels - the labels to use for the output
public Evaluation(java.util.Map<java.lang.Integer,java.lang.String> labels)
labels - a map of label index to label value
public Evaluation(java.util.List<java.lang.String> labels, int topN)
labels - Labels for the classes (may be null)
topN - Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N accuracy, an example is considered 'correct' if the probability for the true class is one of the highest N values
public void eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, ComputationGraph network)
trueLabels - the labels to use
input - the input to the network to use for evaluation
network - the network to use for output
public void eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, MultiLayerNetwork network)
trueLabels - the labels to use
input - the input to the network to use for evaluation
network - the network to use for output
public void eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses)
Note that an IllegalArgumentException is thrown if the two passed-in matrices aren't the same length.
realOutcomes - the real outcomes (labels, usually binary)
guesses - the guesses/predictions (usually a probability vector)
public void eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses, java.util.List<? extends java.io.Serializable> recordMetaData)
Specified by: eval in interface IEvaluation<Evaluation>
Overrides: eval in class BaseEvaluation<Evaluation>
realOutcomes - Data labels
guesses - Network predictions
recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
public void eval(int predictedIdx, int actualIdx)
predictedIdx - Index of class predicted by the network
actualIdx - Index of actual class
public java.lang.String stats()
public java.lang.String stats(boolean suppressWarnings)
suppressWarnings - whether or not to output warnings related to the evaluation results
public double precision(java.lang.Integer classLabel)
classLabel - the label
public double precision(java.lang.Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double precision()
public double recall(java.lang.Integer classLabel)
classLabel - the label
public double recall(java.lang.Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double recall()
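The per-class precision and recall entries above reduce to ratios of one-vs-rest counts, with edgeCase standing in for 0/0, and the no-argument versions average over all known classes. The sketch below shows that arithmetic in plain Java using the standard definitions precision = TP / (TP + FP) and recall = TP / (TP + FN). It is not the DL4J implementation: the class name PrecisionRecallSketch is hypothetical, counts are passed in directly rather than read from the class's counters, and skipping undefined (NaN) classes when averaging is an assumption.

```java
// Hypothetical helper for illustration; not part of DL4J.
class PrecisionRecallSketch {

    // Returns num / den, or edgeCase when the ratio would be 0/0.
    static double ratio(long num, long den, double edgeCase) {
        return den == 0 ? edgeCase : (double) num / den;
    }

    // Precision for one class: TP / (TP + FP).
    static double precision(long tp, long fp, double edgeCase) {
        return ratio(tp, tp + fp, edgeCase);
    }

    // Recall for one class: TP / (TP + FN).
    static double recall(long tp, long fn, double edgeCase) {
        return ratio(tp, tp + fn, edgeCase);
    }

    // Unweighted (macro) average of per-class values. Classes whose value is
    // NaN (used here as the 0/0 marker) are skipped; this is an assumption.
    static double macroAverage(double[] perClass) {
        double sum = 0.0;
        int n = 0;
        for (double v : perClass) {
            if (!Double.isNaN(v)) {
                sum += v;
                n++;
            }
        }
        return n == 0 ? Double.NaN : sum / n;
    }
}
```

For example, a class with TP = 3 and FP = 1 has precision(3, 1, Double.NaN) = 0.75, and averaging per-class precisions of 0.75 and 1.0 gives 0.875.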
public double falsePositiveRate(java.lang.Integer classLabel)
classLabel - the label
public double falsePositiveRate(java.lang.Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double falsePositiveRate()
public double falseNegativeRate(java.lang.Integer classLabel)
classLabel - the label
public double falseNegativeRate(java.lang.Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double falseNegativeRate()
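The false positive and false negative rates follow the same count-ratio pattern. Using the standard definitions FPR = FP / (FP + TN) and FNR = FN / (FN + TP), a minimal sketch looks like this; the class name is hypothetical, counts are supplied directly, and this is not the DL4J code.

```java
// Hypothetical helper for illustration; not part of DL4J.
class ErrorRateSketch {

    // False positive rate for one class: FP / (FP + TN), or edgeCase on 0/0.
    static double falsePositiveRate(long fp, long tn, double edgeCase) {
        long den = fp + tn;
        return den == 0 ? edgeCase : (double) fp / den;
    }

    // False negative rate for one class: FN / (FN + TP), or edgeCase on 0/0.
    static double falseNegativeRate(long fn, long tp, double edgeCase) {
        long den = fn + tp;
        return den == 0 ? edgeCase : (double) fn / den;
    }
}
```

When both are defined for the same class, FNR equals 1 - recall, so a class with TP = 1 and FN = 2 has an FNR of 2/3.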
public double falseAlarmRate()
public double f1(java.lang.Integer classLabel)
classLabel - the label to calculate f1 for
public double f1()
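The F1 formula given in the method summary, 2 * TP / (2 * TP + FP + FN), is the harmonic mean of precision and recall written in terms of raw counts. The sketch below computes it both ways to show the equivalence; the class name is hypothetical and it is not the DL4J implementation.

```java
// Hypothetical helper for illustration; not part of DL4J.
class F1Sketch {

    // F1 for one class from raw counts: 2 * TP / (2 * TP + FP + FN).
    // Returns edgeCase when the denominator is 0 (no TP, FP or FN at all).
    static double f1(long tp, long fp, long fn, double edgeCase) {
        long den = 2 * tp + fp + fn;
        return den == 0 ? edgeCase : (2.0 * tp) / den;
    }

    // The same quantity expressed as the harmonic mean of precision and recall.
    // Returning 0.0 when both are 0 is an assumption of this sketch.
    static double f1FromPrecisionRecall(double precision, double recall) {
        double sum = precision + recall;
        return sum == 0.0 ? 0.0 : 2.0 * precision * recall / sum;
    }
}
```

With TP = 3, FP = 1 and FN = 0, precision is 0.75 and recall is 1.0, and both forms give 6/7 (about 0.857).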
public double accuracy()
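For more than two classes, one reading of (TP + TN) / (P + N) is the fraction of examples whose predicted class matches the actual class: the sum of the confusion matrix diagonal divided by the total count. The sketch assumes a plain int[][] count matrix rather than the ConfusionMatrix class, and returning 0.0 for an empty matrix is an assumption; it is not the DL4J code.

```java
// Hypothetical helper for illustration; not part of DL4J.
class AccuracySketch {

    // counts[actual][predicted] = number of examples. The row/column
    // orientation does not matter here, because only the diagonal
    // (correct predictions) and the grand total are used.
    static double accuracy(int[][] counts) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < counts.length; i++) {
            for (int j = 0; j < counts[i].length; j++) {
                total += counts[i][j];
                if (i == j) {
                    correct += counts[i][j];
                }
            }
        }
        // An empty matrix yields 0.0 (assumption of this sketch).
        return total == 0 ? 0.0 : (double) correct / total;
    }
}
```

For the matrix {{3, 1}, {0, 4}}, 7 of 8 examples lie on the diagonal, so the accuracy is 0.875.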
public double topNAccuracy()
Top N accuracy of the predictions so far. For topN <= 1, this is equivalent to accuracy()
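The topN constructor parameter describes the rule: an example counts as correct if the probability for the true class is one of the N highest values. A minimal sketch of that rule on plain double arrays follows. The class name is hypothetical, and treating ties in favour of the true class (counting only strictly higher probabilities) is an assumption, not something this page specifies.

```java
// Hypothetical helper for illustration; not part of DL4J.
class TopNSketch {

    // True if probs[trueClass] is among the n highest values in probs.
    // Only strictly higher probabilities count against the true class,
    // so ties are resolved in its favour (an assumption).
    static boolean inTopN(double[] probs, int trueClass, int n) {
        double p = probs[trueClass];
        int strictlyHigher = 0;
        for (double q : probs) {
            if (q > p) {
                strictlyHigher++;
            }
        }
        // n <= 1 behaves like standard (top-1) accuracy.
        return strictlyHigher < Math.max(n, 1);
    }

    // Fraction of rows whose true class falls within the top n.
    static double topNAccuracy(double[][] probs, int[] trueClasses, int n) {
        if (probs.length == 0) {
            return 0.0;
        }
        int correct = 0;
        for (int i = 0; i < probs.length; i++) {
            if (inTopN(probs[i], trueClasses[i], n)) {
                correct++;
            }
        }
        return (double) correct / probs.length;
    }
}
```

For the rows {0.1, 0.6, 0.3} (true class 2) and {0.7, 0.2, 0.1} (true class 0), top-1 accuracy is 0.5 and top-2 accuracy is 1.0.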
public java.util.Map<java.lang.Integer,java.lang.Integer> truePositives()
public java.util.Map<java.lang.Integer,java.lang.Integer> trueNegatives()
public java.util.Map<java.lang.Integer,java.lang.Integer> falsePositives()
public java.util.Map<java.lang.Integer,java.lang.Integer> falseNegatives()
public java.util.Map<java.lang.Integer,java.lang.Integer> negative()
public java.util.Map<java.lang.Integer,java.lang.Integer> positive()
public void incrementTruePositives(java.lang.Integer classLabel)
public void incrementTrueNegatives(java.lang.Integer classLabel)
public void incrementFalseNegatives(java.lang.Integer classLabel)
public void incrementFalsePositives(java.lang.Integer classLabel)
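The increment methods and eval(int predictedIdx, int actualIdx) suggest one-vs-rest bookkeeping, where each class is treated in turn as the positive class. The plain-Java sketch below shows that bookkeeping; the class and field names are hypothetical and it is not DL4J's implementation. Converting each row of labels and guesses to a class index with argmax (as one might for one-hot labels and probability vectors) is also an assumption, while the IllegalArgumentException on mismatched lengths mirrors the note on eval(realOutcomes, guesses).

```java
// Hypothetical one-vs-rest counter for illustration; not part of DL4J.
class OneVsRestCounts {
    final int numClasses;
    final int[] tp, fp, fn, tn;

    OneVsRestCounts(int numClasses) {
        this.numClasses = numClasses;
        tp = new int[numClasses];
        fp = new int[numClasses];
        fn = new int[numClasses];
        tn = new int[numClasses];
    }

    // Record one prediction, treating each class in turn as "positive".
    void eval(int predictedIdx, int actualIdx) {
        if (predictedIdx == actualIdx) {
            tp[actualIdx]++;    // correctly accepted
        } else {
            fn[actualIdx]++;    // the actual class was missed
            fp[predictedIdx]++; // the predicted class was wrongly chosen
        }
        // Every class that is neither actual nor predicted was correctly rejected.
        for (int c = 0; c < numClasses; c++) {
            if (c != actualIdx && c != predictedIdx) {
                tn[c]++;
            }
        }
    }

    // Index of the largest value in a row (first one wins on ties).
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    // Row-wise version: reduce labels and guesses to class indices first.
    void eval(double[][] labels, double[][] guesses) {
        if (labels.length != guesses.length) {
            throw new IllegalArgumentException("labels and guesses must have the same number of rows");
        }
        for (int i = 0; i < labels.length; i++) {
            eval(argMax(guesses[i]), argMax(labels[i]));
        }
    }
}
```

After eval(0, 0) and eval(1, 0) on three classes, class 0 has one true positive and one false negative, class 1 has one false positive, and class 2 has two true negatives.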
public void addToConfusion(java.lang.Integer real, java.lang.Integer guess)
real - the actual class
guess - the system guess
public int classCount(java.lang.Integer clazz)
clazz - the label
public int getNumRowCounter()
public int getTopNCorrectCount()
public int getTopNTotalCount()
Returns the total number of top N evaluations. This is usually equal to getNumRowCounter(), but may differ when eval(int, int) is used, as top N accuracy cannot be calculated in that case (it requires the full probability distribution, not just the predicted/actual indices)
public java.lang.String getClassLabel(java.lang.Integer clazz)
public ConfusionMatrix<java.lang.Integer> getConfusionMatrix()
public void merge(Evaluation other)
other - Evaluation object to merge into this one.
public java.lang.String confusionToString()
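Because every metric on this page is a ratio of counts, two evaluations computed on separate parts of a test set can be combined by adding their counts and recomputing the metrics; averaging per-part metrics instead would weight parts of different sizes equally. The sketch below assumes that is what merge(Evaluation other) amounts to for the confusion counts, which this page does not state, and render is only a rough stand-in for confusionToString(). The class name is hypothetical and neither method is DL4J code.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of DL4J.
class MergeSketch {

    // Element-wise sum of two count matrices with identical shape.
    static int[][] mergeCounts(int[][] a, int[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("matrices must have the same number of rows");
        }
        int[][] merged = new int[a.length][];
        for (int i = 0; i < a.length; i++) {
            if (a[i].length != b[i].length) {
                throw new IllegalArgumentException("row " + i + " has mismatched lengths");
            }
            merged[i] = new int[a[i].length];
            for (int j = 0; j < a[i].length; j++) {
                merged[i][j] = a[i][j] + b[i][j];
            }
        }
        return merged;
    }

    // Simple text rendering of a count matrix, one row per line.
    static String render(int[][] counts) {
        StringBuilder sb = new StringBuilder();
        for (int[] row : counts) {
            sb.append(Arrays.toString(row)).append('\n');
        }
        return sb.toString();
    }
}
```

Metrics such as accuracy should then be recomputed from the merged matrix rather than combined from the two inputs' metric values.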
public java.util.List<Prediction> getPredictionErrors()
Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that eval method), there is no value in
splitting each prediction out into a separate Prediction object; instead, use the confusion matrix to get the counts,
via getConfusionMatrix()
public java.util.List<Prediction> getPredictionsByActualClass(int actualClass)
Note: Predictions are ONLY available if the "evaluate with metadata" method is used: eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that eval method), there is no value in
splitting each prediction out into a separate Prediction object; instead, use the confusion matrix to get the counts,
via getConfusionMatrix()
actualClass - Actual class to get predictions for
public java.util.List<Prediction> getPredictionByPredictedClass(int predictedClass)
Note: Predictions are ONLY available if the "evaluate with metadata" method is used: eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that eval method), there is no value in
splitting each prediction out into a separate Prediction object; instead, use the confusion matrix to get the counts,
via getConfusionMatrix()
predictedClass - Predicted class to get predictions for
public java.util.List<Prediction> getPredictions(int actualClass, int predictedClass)
actualClass - Actual class
predictedClass - Predicted class
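The prediction-lookup methods above all filter per-record metadata stored by (actual, predicted) class pair, which is the shape of the confusionMatrixMetaData field. The sketch below mimics that lookup with standard collections. PredictionSketch and MetaDataStore are hypothetical names, and using a two-element List<Integer> as the map key in place of Pair is an assumption; in the real class, metadata is only recorded through eval(INDArray, INDArray, List).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;

// Hypothetical per-record prediction; not DL4J's Prediction class.
final class PredictionSketch {
    final int actualClass;
    final int predictedClass;
    final Object recordMetaData;

    PredictionSketch(int actualClass, int predictedClass, Object recordMetaData) {
        this.actualClass = actualClass;
        this.predictedClass = predictedClass;
        this.recordMetaData = recordMetaData;
    }
}

// Hypothetical store keyed by (actual, predicted), loosely mirroring confusionMatrixMetaData.
class MetaDataStore {
    private final Map<List<Integer>, List<Object>> byPair = new HashMap<>();

    void record(int actual, int predicted, Object meta) {
        byPair.computeIfAbsent(List.of(actual, predicted), k -> new ArrayList<>()).add(meta);
    }

    // Collect every stored record whose (actual, predicted) pair passes the filter.
    private List<PredictionSketch> collect(BiPredicate<Integer, Integer> keep) {
        List<PredictionSketch> out = new ArrayList<>();
        for (Map.Entry<List<Integer>, List<Object>> e : byPair.entrySet()) {
            int actual = e.getKey().get(0);
            int predicted = e.getKey().get(1);
            if (keep.test(actual, predicted)) {
                for (Object meta : e.getValue()) {
                    out.add(new PredictionSketch(actual, predicted, meta));
                }
            }
        }
        return out;
    }

    List<PredictionSketch> getPredictionErrors() {
        return collect((a, p) -> !a.equals(p));
    }

    List<PredictionSketch> getPredictionsByActualClass(int actualClass) {
        return collect((a, p) -> a == actualClass);
    }

    List<PredictionSketch> getPredictionsByPredictedClass(int predictedClass) {
        return collect((a, p) -> p == predictedClass);
    }

    List<PredictionSketch> getPredictions(int actualClass, int predictedClass) {
        return collect((a, p) -> a == actualClass && p == predictedClass);
    }
}
```

The order of returned records follows HashMap iteration and is unspecified, so callers should not rely on it.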