I have written a metric for TensorFlow that computes the area under the Precision-Recall curve to the left of the recall = 0.3 threshold. The implementation is below (note that I only care about the zeroth prediction of the last timestep in each window):
def auprc_left_of_recall(y_true, y_pred, recall_threshold=0.3):
    """Area under the precision-recall curve left of ``recall_threshold``.

    Samples are ranked by descending score; precision and recall are
    accumulated along that ranking, the curve is anchored at
    (recall=0, precision=1), cut at ``recall_threshold`` by linear
    interpolation, and integrated with the trapezoidal rule.

    The value is computed only from the tensors passed to this call (i.e.
    one batch when used as a Keras metric function).

    Args:
        y_true: Labels, indexed as ``[:, -1, 0]`` (zeroth output of the last
            timestep). Presumably binary 0/1 values -- confirm with caller.
        y_pred: Scores, indexed the same way as ``y_true``.
        recall_threshold: Recall value at which the curve is cut.

    Returns:
        Scalar tensor holding the partial area. If recall never reaches the
        threshold (e.g. no positives in the input), the area under all
        available points is returned, which is 0 when there are no positives.
    """
    dtype = y_pred.dtype
    # Only the zeroth output of the last timestep is scored.
    y_pred = y_pred[:, -1, 0]
    # Cast so integer labels do not break the float arithmetic below.
    y_true = tf.cast(y_true[:, -1, 0], dtype)
    recall_threshold = tf.cast(recall_threshold, dtype)

    # Rank samples from highest to lowest score.
    indices = tf.argsort(y_pred, direction='DESCENDING')
    sorted_y_true = tf.gather(y_true, indices)

    cum_true_positives = K.cumsum(sorted_y_true)
    cum_false_positives = K.cumsum(1 - sorted_y_true)
    total_positives = K.sum(sorted_y_true)

    precision = cum_true_positives / (
        cum_true_positives + cum_false_positives + K.epsilon())
    recall = cum_true_positives / (total_positives + K.epsilon())

    # Anchor the curve at (recall=0, precision=1) *before* searching for the
    # threshold crossing, so there is always a valid "before" point even when
    # the top-ranked sample alone already reaches the threshold.
    precision = tf.concat([tf.ones([1], dtype=dtype), precision], axis=0)
    recall = tf.concat([tf.zeros([1], dtype=dtype), recall], axis=0)

    before_mask = recall < recall_threshold
    after_mask = recall >= recall_threshold
    has_crossing = tf.reduce_any(after_mask)
    last_idx = tf.shape(recall)[0] - 1

    # argmax on a 0/1 vector yields the first index at/after the threshold.
    # Without a crossing, fall back to the last point so indexing stays valid.
    first_after = tf.argmax(tf.cast(after_mask, tf.int32), output_type=tf.int32)
    idx_after = tf.where(has_crossing, first_after, last_idx)
    # Recall is non-decreasing, so the last point below the threshold is the
    # one immediately preceding the first point at/after it.
    idx_before = tf.maximum(idx_after - 1, 0)

    recall_before = tf.gather(recall, idx_before)
    recall_after = tf.gather(recall, idx_after)
    precision_before = tf.gather(precision, idx_before)
    precision_after = tf.gather(precision, idx_after)

    # Without a crossing the curve simply ends at its last point.
    end_recall = tf.where(
        has_crossing, recall_threshold, tf.gather(recall, last_idx))

    # divide_no_nan guards the degenerate zero-width segment.
    weight = tf.math.divide_no_nan(
        end_recall - recall_before, recall_after - recall_before)
    interpolated_precision = (
        precision_before + (precision_after - precision_before) * weight)

    precision_left = tf.concat(
        [tf.boolean_mask(precision, before_mask),
         tf.expand_dims(interpolated_precision, 0)],
        axis=0)
    recall_left = tf.concat(
        [tf.boolean_mask(recall, before_mask),
         tf.expand_dims(end_recall, 0)],
        axis=0)

    # Trapezoidal rule over the retained part of the curve.
    recall_diff = recall_left[1:] - recall_left[:-1]
    avg_precision = (precision_left[1:] + precision_left[:-1]) / 2
    area = tf.reduce_sum(recall_diff * avg_precision)
    return area
However, when the metric is aggregated by averaging its per-batch values, the result does not match the area computed over the full dataset (it is only correct when there is a single validation/test batch).
How would I implement a metric which evaluates exactly what the metric above does, but over the labels and predictions of all batches at once?