Data Science Asked by Adrian Buzea on February 3, 2021
I’m training a binary classifier and I’d like to see Precision/Recall metrics at different thresholds.
In TensorFlow 2.3, tf.keras.metrics.Precision and tf.keras.metrics.Recall take a thresholds
parameter where you can specify one or more thresholds at which the metric is computed. This works as advertised, i.e.:
import tensorflow as tf

m = tf.keras.metrics.Precision(thresholds=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
m.update_state([0, 1, 0, 1], [0.4, 0.5, 0.3, 0.8])  # (y_true, y_pred)
m.result().numpy()
This returns the precision at each threshold, [0.5, 0.5, 0.6666667, 1., 1., 1.], as per the documentation.
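To see what the thresholds parameter does, here is a minimal NumPy sketch that recomputes those numbers by hand. It assumes, as we understand Keras' implementation, that a prediction counts as positive when its score is strictly greater than the threshold; the function name precision_at_thresholds is our own and is not part of TensorFlow.

```python
import numpy as np

def precision_at_thresholds(y_true, y_prob, thresholds):
    """Precision for each threshold, treating y_prob > thr as a positive prediction."""
    y_true = np.asarray(y_true, dtype=bool)
    y_prob = np.asarray(y_prob, dtype=float)
    result = []
    for thr in thresholds:
        pred_pos = y_prob > thr              # strict comparison (assumed to match Keras)
        tp = np.sum(pred_pos & y_true)       # true positives at this threshold
        fp = np.sum(pred_pos & ~y_true)      # false positives at this threshold
        result.append(tp / (tp + fp) if tp + fp > 0 else 0.0)
    return np.array(result)

print(precision_at_thresholds([0, 1, 0, 1], [0.4, 0.5, 0.3, 0.8],
                              [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]))
```

For example, at threshold 0.3 the score 0.3 is not counted as positive, leaving two true positives and one false positive, hence 2/3.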
However, when these metrics are passed to Model.compile,
I get a single value per metric regardless of how many thresholds I specify:
import numpy as np
from tensorflow import keras

# model: an already-built binary classifier
pr_thresholds = list(np.arange(0.05, 0.95, 0.05))
model.compile(
    'adam',
    'binary_crossentropy',
    metrics=[
        keras.metrics.Precision(thresholds=pr_thresholds),
        keras.metrics.Recall(thresholds=pr_thresholds),
    ]
)
During training I get:
Epoch 34/50
395/395 [==============================] - 22s 54ms/step - loss: 0.4314 - precision: 0.7886 - recall: 0.9008 - val_loss: 0.5113 - val_precision: 0.7434 - val_recall: 0.8769
What’s happening here? Does it always use the default threshold of 0.5 in this case?
Is there a way to get it to display the values for multiple thresholds during training?
You can see the metric value for each threshold during fitting if you explicitly instantiate a separate metric object for each threshold, as follows:
from tensorflow import keras
from tensorflow.keras import metrics

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-2),
    loss='categorical_crossentropy',
    metrics=[metrics.Recall(thresholds=0.6),   # one metric object per threshold
             metrics.Recall(thresholds=0.9)])
model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))
As you can see in the image below, in every epoch the first recall value (threshold 0.6) is higher than the second (threshold 0.9), as expected.
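The ordering follows from the definition recall = TP / (TP + FN): the denominator is the number of actual positives and does not depend on the threshold, while raising the threshold can only turn predicted positives into predicted negatives, so TP (and therefore recall) never increases. A quick NumPy check, using made-up labels and scores purely for illustration and assuming a prediction is positive when its score is strictly above the threshold:

```python
import numpy as np

def recall_at_thresholds(y_true, y_prob, thresholds):
    """Recall for each threshold, treating y_prob > thr as a positive prediction."""
    y_true = np.asarray(y_true, dtype=bool)
    y_prob = np.asarray(y_prob, dtype=float)
    actual_pos = np.sum(y_true)              # TP + FN: fixed, independent of the threshold
    out = []
    for thr in thresholds:
        tp = np.sum((y_prob > thr) & y_true)  # true positives at this threshold
        out.append(tp / actual_pos if actual_pos > 0 else 0.0)
    return np.array(out)

# Hypothetical labels and predicted probabilities, for illustration only.
y_true = [1, 1, 1, 0, 1, 0, 1, 0]
y_prob = [0.95, 0.85, 0.65, 0.7, 0.4, 0.2, 0.92, 0.62]
thresholds = [0.6, 0.7, 0.9]
recalls = recall_at_thresholds(y_true, y_prob, thresholds)
for thr, r in zip(thresholds, recalls):
    print(f"recall@{thr}: {r:.3f}")
```

On this toy data the recalls come out as 0.8, 0.6 and 0.4: each step up in threshold drops one or more true positives while the count of actual positives stays at 5.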
For your case, you can build the list of metric objects programmatically; with the three thresholds below you will see three recall values per epoch:
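from tensorflow.keras import metrics
thresholds = [0.6, 0.7, 0.9]
metrics_objs_list = [metrics.Recall(thresholds=thr) for thr in thresholds]  # one Recall per threshold
# then pass it to compile: model.compile(..., metrics=metrics_objs_list)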
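With many thresholds, such as the list produced by np.arange(0.05, 0.95, 0.05) in the question, the default metric names in the training log (typically recall, recall_1, ...) are hard to map back to thresholds, and np.arange's floating-point steps can leave small rounding noise in the values. One option is to round the thresholds and give each metric an explicit name through the name argument that Keras metric constructors accept. The sketch below is plain Python/NumPy; the helper named_thresholds and the "recall_at_" naming scheme are hypothetical choices of ours, and the TensorFlow usage is shown only as a comment.

```python
import numpy as np

def named_thresholds(thresholds, prefix, decimals=2):
    """Hypothetical helper: return (threshold, metric_name) pairs with rounded values."""
    pairs = []
    for thr in np.round(np.asarray(thresholds, dtype=float), decimals):
        thr = float(thr)  # plain Python float rather than a NumPy scalar
        # Format with a fixed number of decimals; '.' is replaced by '_' as a
        # conservative choice so names stay identifier-like (e.g. recall_at_0_60).
        name = f"{prefix}{thr:.{decimals}f}".replace(".", "_")
        pairs.append((thr, name))
    return pairs

pr_thresholds = np.arange(0.05, 0.95, 0.05)  # same range as in the question
for thr, name in named_thresholds(pr_thresholds, "recall_at_")[:3]:
    print(thr, name)

# With TensorFlow installed, each pair could become one metric object, e.g.:
# metrics_objs_list = [metrics.Recall(thresholds=thr, name=name)
#                      for thr, name in named_thresholds(pr_thresholds, "recall_at_")]
```

The trade-off is a longer progress-bar line, one entry per threshold, in exchange for being able to read each threshold's value directly.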
Answered by German C M on February 3, 2021