Java Deep Learning Projects

Evaluating the model

Once training has been completed, the next task is to evaluate the model's performance on the test set. For this, we will use DL4J's Evaluation class: new Evaluation(2) creates an evaluation object for two possible classes (survived or not survived). The Evaluation class computes metrics such as accuracy, precision, recall, F1, and the Matthews correlation coefficient (MCC). MCC is designed specifically for evaluating binary classifiers. Let's take a brief look at these metrics, writing TP, FP, TN, and FN for the numbers of true positives, false positives, true negatives, and false negatives:

Accuracy is the ratio of correctly predicted samples to the total number of samples:

$$\text{Accuracy} = \frac{TP + TN}{TP + FP + FN + TN}$$

Precision is the ratio of correctly predicted positive samples to all samples predicted as positive:

$$\text{Precision} = \frac{TP}{TP + FP}$$

Recall is the ratio of correctly predicted positive samples to all samples that actually belong to the positive class:

$$\text{Recall} = \frac{TP}{TP + FN}$$

F1 score is the harmonic mean of precision and recall:

$$F_1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$
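To make the definitions above concrete, here is a minimal sketch in plain Java (no DL4J calls) that computes accuracy, precision, recall, and F1 from the four confusion-matrix counts. The class and method names are illustrative, the counts in main are hypothetical, and returning 0 when a denominator is zero is an assumed convention.

```java
/**
 * Minimal sketch: accuracy, precision, recall, and F1 for a binary
 * classifier, computed from TP, FP, TN, and FN.
 * Class and method names are illustrative, not part of DL4J.
 */
public class BinaryMetrics {

    // (TP + TN) / (TP + FP + TN + FN)
    static double accuracy(long tp, long fp, long tn, long fn) {
        return (double) (tp + tn) / (tp + fp + tn + fn);
    }

    // TP / (TP + FP): fraction of predicted positives that are truly positive
    static double precision(long tp, long fp) {
        return tp + fp == 0 ? 0.0 : (double) tp / (tp + fp); // assumed convention for 0/0
    }

    // TP / (TP + FN): fraction of actual positives that were found
    static double recall(long tp, long fn) {
        return tp + fn == 0 ? 0.0 : (double) tp / (tp + fn); // assumed convention for 0/0
    }

    // Harmonic mean of precision and recall
    static double f1(double precision, double recall) {
        return precision + recall == 0 ? 0.0
                : 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        // Hypothetical counts, not taken from the Titanic results
        long tp = 30, fp = 10, tn = 50, fn = 10;
        double p = precision(tp, fp);
        double r = recall(tp, fn);
        System.out.println("Accuracy:  " + accuracy(tp, fp, tn, fn)); // 0.8
        System.out.println("Precision: " + p);                        // 0.75
        System.out.println("Recall:    " + r);                        // 0.75
        System.out.println("F1:        " + f1(p, r));                 // 0.75
    }
}
```

Because F1 is a harmonic mean, it is pulled toward the smaller of precision and recall. When precision and recall are equal, as here, F1 equals both.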

Matthews correlation coefficient (MCC) measures the quality of binary (two-class) classifications. It can be calculated directly from the confusion matrix, given the counts TP, FP, TN, and FN:

$$\text{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$
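As a minimal sketch, the MCC formula can be computed from the four counts as shown below. The class name and the counts in main are hypothetical. Reporting 0 when a marginal sum is zero, which makes the formula undefined, is an assumed convention.

```java
/**
 * Minimal sketch of the MCC formula, computed from TP, FP, TN, and FN.
 * The class name is illustrative, not part of DL4J.
 */
public class MatthewsCorrelation {

    static double mcc(long tp, long fp, long tn, long fn) {
        // Numerator: TP*TN - FP*FN (computed in double to avoid long overflow)
        double numerator = (double) tp * tn - (double) fp * fn;
        // Denominator: sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
        double denominator = Math.sqrt((double) (tp + fp) * (tp + fn)
                * (tn + fp) * (tn + fn));
        // Assumed convention: undefined MCC (zero marginal) is reported as 0
        return denominator == 0 ? 0.0 : numerator / denominator;
    }

    public static void main(String[] args) {
        // Hypothetical counts: TP=30, FP=10, TN=50, FN=10
        System.out.println("MCC: " + mcc(30, 10, 50, 10));
    }
}
```

Unlike accuracy, MCC uses all four cells of the confusion matrix symmetrically. A classifier therefore cannot score well simply by favoring the majority class.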

Unlike Apache Spark's classification evaluator, DL4J's evaluator requires special care when solving a binary classification problem. As the output below shows, it reports precision, recall, and F1 for the positive class (class 1) only.
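To see why the choice of positive class matters, here is a minimal sketch in plain Java (no DL4J calls). It computes precision, recall, and F1 for each class of a hypothetical 2×2 confusion matrix, plus a simple macro average (the unweighted mean over classes). Names and counts are illustrative only.

```java
/**
 * Minimal sketch: per-class precision, recall, and F1 for a binary
 * problem, and their macro average. Does not call DL4J.
 */
public class PerClassMetrics {

    // c[actual][predicted]; precision of a class uses its column
    static double precision(long[][] c, int cls) {
        long predicted = c[0][cls] + c[1][cls];
        return predicted == 0 ? 0.0 : (double) c[cls][cls] / predicted;
    }

    // Recall of a class uses its row
    static double recall(long[][] c, int cls) {
        long actual = c[cls][0] + c[cls][1];
        return actual == 0 ? 0.0 : (double) c[cls][cls] / actual;
    }

    static double f1(long[][] c, int cls) {
        double p = precision(c, cls), r = recall(c, cls);
        return p + r == 0 ? 0.0 : 2 * p * r / (p + r);
    }

    public static void main(String[] args) {
        // Hypothetical test-set counts: rows = actual class, columns = predicted class
        long[][] confusion = {
            {50, 10},   // actual 0: 50 predicted 0, 10 predicted 1
            {20, 20}    // actual 1: 20 predicted 0, 20 predicted 1
        };
        for (int cls = 0; cls < 2; cls++) {
            System.out.printf("class %d: precision=%.3f recall=%.3f f1=%.3f%n",
                    cls, precision(confusion, cls), recall(confusion, cls),
                    f1(confusion, cls));
        }
        double macroF1 = (f1(confusion, 0) + f1(confusion, 1)) / 2;
        System.out.printf("macro-averaged F1=%.3f%n", macroF1);
    }
}
```

With these counts, F1 is about 0.769 for class 0 but about 0.571 for class 1. The same model therefore looks noticeably better or worse depending on which class is reported.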

We will come back to this later. First, let's iterate over the test set and get the trained network's predictions for each batch. The eval() method then compares these predictions against the true classes:

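import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

log.info("Evaluate model....");
Evaluation eval = new Evaluation(2); // two classes: 0 = not survived, 1 = survived

// Iterate over the test set one mini-batch at a time
while (testDataIt.hasNext()) {
    DataSet next = testDataIt.next();
    // Network predictions for this batch (later DL4J releases use next.getFeatures())
    INDArray output = model.output(next.getFeatureMatrix());
    // Compare predictions with the true labels and update the statistics
    eval.eval(next.getLabels(), output);
}
log.info(eval.stats());
log.info("****************Example finished********************");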
>>>
==========================Scores========================================
# of classes: 2
Accuracy: 0.6496
Precision: 0.6155
Recall: 0.5803
F1 Score: 0.3946
Precision, recall & F1: reported for positive class (class 1 - "1") only
=======================================================================
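Conceptually, eval.eval() accumulates a confusion matrix batch by batch, and the scores above are derived from it. The following minimal sketch mimics that bookkeeping with plain arrays instead of INDArray. It assumes one-hot labels and an argmax decision rule over the network output. It is an illustration only, not DL4J's actual implementation, and all names and values are hypothetical.

```java
/**
 * Minimal sketch of confusion-matrix bookkeeping over mini-batches,
 * assuming one-hot labels and argmax predictions. Not DL4J code.
 */
public class ConfusionSketch {

    // Index of the largest value in a row (first one wins on ties)
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) best = i;
        }
        return best;
    }

    // confusion[actual][predicted] is incremented once per sample
    static void accumulate(long[][] confusion, double[][] labels, double[][] output) {
        for (int i = 0; i < labels.length; i++) {
            confusion[argMax(labels[i])][argMax(output[i])]++;
        }
    }

    public static void main(String[] args) {
        long[][] confusion = new long[2][2];
        // Two hypothetical mini-batches of one-hot labels and softmax outputs
        double[][][] labelBatches = {
            {{1, 0}, {0, 1}},
            {{0, 1}, {1, 0}}
        };
        double[][][] outputBatches = {
            {{0.8, 0.2}, {0.3, 0.7}},
            {{0.6, 0.4}, {0.9, 0.1}}
        };
        for (int b = 0; b < labelBatches.length; b++) {
            accumulate(confusion, labelBatches[b], outputBatches[b]);
        }
        long correct = confusion[0][0] + confusion[1][1];
        long total = correct + confusion[0][1] + confusion[1][0];
        System.out.println("Accuracy: " + (double) correct / total); // 0.75
    }
}
```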

Oops! Unfortunately, we have not managed to achieve very high classification accuracy (about 65%). Next, we compute another metric, MCC, for this binary classification problem:

import org.deeplearning4j.eval.EvaluationAveraging;

// Compute the Matthews correlation coefficient, macro-averaged over both classes
EvaluationAveraging averaging = EvaluationAveraging.Macro;
double mcc = eval.matthewsCorrelation(averaging);
System.out.println("Matthews correlation coefficient: " + mcc);
>>>
Matthews correlation coefficient: 0.22308172619187497
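The listing above requests a macro-averaged MCC. For two classes, this should coincide with the ordinary binary MCC, assuming the macro average is the mean of the two per-class MCC values. Treating class 0 as positive swaps TP with TN and FP with FN, and the MCC formula is unchanged by that swap. The hypothetical sketch below checks this numerically with made-up counts.

```java
/**
 * Minimal sketch: for two classes, the MCC with class 1 as positive equals
 * the MCC with class 0 as positive, so their mean (macro average) equals both.
 * Counts are hypothetical; this does not call DL4J.
 */
public class MacroMccSketch {

    static double mcc(long tp, long fp, long tn, long fn) {
        double num = (double) tp * tn - (double) fp * fn;
        double den = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return den == 0 ? 0.0 : num / den;
    }

    public static void main(String[] args) {
        long tp = 20, fp = 10, tn = 50, fn = 20;   // class 1 treated as positive
        double mccClass1 = mcc(tp, fp, tn, fn);
        // Class 0 as positive: TP' = TN, FP' = FN, TN' = TP, FN' = FP
        double mccClass0 = mcc(tn, fn, tp, fp);
        double macro = (mccClass1 + mccClass0) / 2;
        System.out.println("class 1: " + mccClass1);
        System.out.println("class 0: " + mccClass0);
        System.out.println("macro:   " + macro);
    }
}
```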

Now let's interpret this result using Matthews' paper (see www.sciencedirect.com/science/article/pii/0005279575901099), which describes the following properties: a correlation of C = 1 indicates perfect agreement, C = 0 is expected for a prediction no better than random, and C = -1 indicates total disagreement between prediction and observation.
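To make these three reference points concrete, the following sketch evaluates the MCC formula on three hypothetical confusion matrices. The first has perfect predictions, the second has fully inverted predictions, and the third has predictions that are independent of the true class (half of each actual class is predicted positive). The class name and counts are illustrative.

```java
/**
 * Minimal sketch: MCC at its reference points 1, -1, and 0,
 * using hypothetical confusion-matrix counts.
 */
public class MccRangeSketch {

    static double mcc(long tp, long fp, long tn, long fn) {
        double num = (double) tp * tn - (double) fp * fn;
        double den = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return den == 0 ? 0.0 : num / den;
    }

    public static void main(String[] args) {
        // 40 positives and 60 negatives, all classified correctly
        System.out.println("perfect:     " + mcc(40, 0, 60, 0));   // 1.0
        // Every positive predicted negative and vice versa
        System.out.println("inverted:    " + mcc(0, 60, 0, 40));   // -1.0
        // Prediction rate of 50% in both actual classes
        System.out.println("independent: " + mcc(20, 30, 30, 20)); // 0.0
    }
}
```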

Based on this, our MCC of about 0.22 indicates only a weak positive correlation between the predictions and the actual outcomes. Although we have not achieved good accuracy, you can still try to improve it by tuning the hyperparameters or by switching to a different network, such as an LSTM, which we will discuss in the next section. There, however, we will use an LSTM to solve our cancer prediction problem, which is the main goal of this chapter. So stay with me!