TransWikia.com

Keras P/R metrics at different thresholds during training

Data Science Asked by Adrian Buzea on February 3, 2021

I’m training a binary classifier and I’d like to see Precision/Recall metrics at different thresholds.

In TensorFlow 2.3, tf.keras.metrics.Precision and tf.keras.metrics.Recall take a thresholds parameter, where you can specify one or more thresholds at which the metric should be computed. This works as advertised, e.g.:

import tensorflow as tf

# One precision value is computed for each threshold in the list
m = tf.keras.metrics.Precision(thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
m.update_state([0, 1, 0, 1], [0.4, 0.5, 0.3, 0.8])
m.result().numpy()

This returns the precision at each threshold, [0.5, 0.5, 0.6666667, 1., 1., 1.], as described in the documentation.
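To make the threshold semantics concrete, here is a minimal pure-Python sketch that recomputes the numbers above without TensorFlow. It assumes Keras's convention that a score counts as a positive prediction only when it is strictly greater than the threshold (the documented 0.6666667 at threshold 0.3 is consistent with this, since the sample scored exactly 0.3 is not counted), and it returns 0.0 when a denominator is zero. The function name `precision_recall_at_thresholds` is hypothetical and not part of Keras.

```python
def precision_recall_at_thresholds(y_true, y_pred, thresholds):
    """Return (precisions, recalls), one entry per threshold.

    Assumption: a score is a positive prediction only if score > threshold.
    """
    precisions, recalls = [], []
    for t in thresholds:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for y, p in pairs if y == 1 and p > t)   # true positives
        fp = sum(1 for y, p in pairs if y == 0 and p > t)   # false positives
        fn = sum(1 for y, p in pairs if y == 1 and p <= t)  # false negatives
        # Return 0.0 instead of dividing by zero
        precisions.append(tp / (tp + fp) if tp + fp else 0.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
    return precisions, recalls


precisions, recalls = precision_recall_at_thresholds(
    [0, 1, 0, 1], [0.4, 0.5, 0.3, 0.8], [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
)
print([round(p, 4) for p in precisions])  # [0.5, 0.5, 0.6667, 1.0, 1.0, 1.0]
print(recalls)                            # [1.0, 1.0, 1.0, 1.0, 0.5, 0.5]
```

The recall list shows the same per-threshold structure that tf.keras.metrics.Recall(thresholds=[...]) produces as a vector.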

However, when these are passed as metrics to Model.compile, I get a single value per metric, regardless of how many thresholds I specify.


import numpy as np
from tensorflow import keras

pr_thresholds = list(np.arange(0.05, 0.95, 0.05))
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        keras.metrics.Precision(thresholds=pr_thresholds),
        keras.metrics.Recall(thresholds=pr_thresholds),
    ],
)

I get

Epoch 34/50
395/395 [==============================] - 22s 54ms/step - loss: 0.4314 - precision: 0.7886 - recall: 0.9008 - val_loss: 0.5113 - val_precision: 0.7434 - val_recall: 0.8769

What’s happening here? Does it always use the default threshold of 0.5 in this case?

Is there a way to get it to display the values for multiple thresholds during training?

One Answer

You can see the metric value for each threshold during fitting if you explicitly instantiate a separate metric object for each threshold, as follows:

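from tensorflow import keras
from tensorflow.keras import metrics

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-2),
    loss='categorical_crossentropy',
    metrics=[metrics.Recall(thresholds=0.6),   # one metric object per threshold
             metrics.Recall(thresholds=0.9)])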

model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))

As the image below shows, in every epoch the first recall value (threshold 0.6) is higher than the second one (threshold 0.9), as expected:
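Why this ordering is expected can be shown in two lines. Writing $\hat{p}_i$ for the predicted score of sample $i$ and assuming the same strict "score > threshold" convention, the denominator of recall is just the number of actual positives, which does not depend on the threshold $t$:

```latex
\mathrm{Recall}(t) = \frac{TP(t)}{TP(t) + FN(t)}
  = \frac{\left|\{\, i : y_i = 1,\ \hat{p}_i > t \,\}\right|}{\left|\{\, i : y_i = 1 \,\}\right|}

t_1 < t_2 \implies \{\, i : \hat{p}_i > t_2 \,\} \subseteq \{\, i : \hat{p}_i > t_1 \,\}
  \implies \mathrm{Recall}(t_2) \le \mathrm{Recall}(t_1)
```

So recall at 0.6 can never be lower than recall at 0.9; the two are equal only when no positive sample scores in (0.6, 0.9]. Precision has no such guarantee, because its denominator $TP(t) + FP(t)$ also shrinks as $t$ grows.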

[Image: training log showing the two recall values per epoch]

For your case, you can build the list of metric objects programmatically; with the thresholds below, three recall values are shown per epoch:

from tensorflow.keras import metrics
thresholds = [0.6, 0.7, 0.9]
metrics_objs_list = [metrics.Recall(thresholds=thr) for thr in thresholds]
# Then pass metrics=metrics_objs_list to model.compile(...)

[Image: training log showing the three recall values per epoch]

Answered by German C M on February 3, 2021
