Multi-class Precision and Recall #1753
@Squadrick Please check if this feature has already been added in the TensorFlow main code base. There is a similar issue in the tensorflow repo as well, but it seems to be unresolved. Thank you
We have something in TFX. See https://www.tensorflow.org/tfx/model_analysis/metrics#multi-classmulti-label_classification_metrics
In case it's useful, I gave an example on how to adapt the existing binary label-oriented metrics for a multi-class setting in tensorflow/tensorflow#37256 (comment). HTH!
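In case a concrete sketch helps: `tf.keras.metrics.Precision` and `tf.keras.metrics.Recall` accept a `class_id` argument, so with one-hot targets they behave as one-vs-rest binary metrics per class. A minimal sketch along those lines (the data and the macro average below are illustrative, and this is not necessarily the exact approach in the linked comment):

```python
import tensorflow as tf

num_classes = 3
y_true = tf.constant([0, 1, 2, 1])             # sparse integer labels (made up)
y_pred = tf.constant([[0.8, 0.1, 0.1],         # model probabilities (made up)
                      [0.2, 0.7, 0.1],
                      [0.1, 0.2, 0.7],
                      [0.6, 0.3, 0.1]])

y_true_oh = tf.one_hot(y_true, num_classes)
# One-hot of the argmax, so each sample predicts exactly one class.
y_pred_oh = tf.one_hot(tf.argmax(y_pred, axis=-1), num_classes)

per_class_precision = []
for k in range(num_classes):
    m = tf.keras.metrics.Precision(class_id=k)  # binary metric on class k only
    m.update_state(y_true_oh, y_pred_oh)
    per_class_precision.append(float(m.result()))

macro_precision = sum(per_class_precision) / num_classes
print(per_class_precision, macro_precision)
```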
Perhaps I am misunderstanding, but I have been running a multiclass classification model and using the following precision and recall metrics:
One thing I am having trouble with is multiclass classification reports from sklearn. Any pointers, or other good issue threads people have seen?
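On the sklearn classification-report question: a common pattern is to argmax the model's predicted probabilities into integer labels and hand both arrays to `sklearn.metrics.classification_report`, which prints per-class precision/recall plus the micro/macro/weighted averages. A small sketch with made-up data (assuming scikit-learn is installed):

```python
import numpy as np
from sklearn.metrics import classification_report

y_true = np.array([0, 1, 2, 1])                # integer labels (made up)
probs = np.array([[0.8, 0.1, 0.1],             # model probabilities (made up)
                  [0.2, 0.7, 0.1],
                  [0.1, 0.2, 0.7],
                  [0.6, 0.3, 0.1]])
y_pred = probs.argmax(axis=-1)                 # integer class predictions

print(classification_report(y_true, y_pred))
# output_dict=True gives the same numbers as a nested dict for programmatic use.
report = classification_report(y_true, y_pred, output_dict=True)
```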
If anyone searches for this, maybe this will help. I created a new metric to get a multi-class confusion matrix. I know we already have one in addons, but it wasn't helping my cause. The metric looks like this:

```python
class MultiClassConfusionMatrix(tf.keras.metrics.Metric):
    def __init__(self, num_classes, name="multi_class_confusion_matrix", **kwargs):
        super(MultiClassConfusionMatrix, self).__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.mctp_conf_matrix = self.add_weight(
            name="confusion_matrix",
            shape=(num_classes, num_classes),
            dtype=tf.dtypes.int32,
            initializer="zeros",
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Flatten the sparse integer labels and argmax the predicted probabilities.
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.argmax(y_pred, axis=-1)
        # Add this batch's confusion matrix to the running total.
        tmp = tf.math.confusion_matrix(y_true, y_pred, num_classes=self.num_classes)
        self.mctp_conf_matrix.assign(tf.math.add(tmp, self.mctp_conf_matrix))

    def result(self):
        return self.mctp_conf_matrix

    def reset_states(self):
        """In version 2.5.0 this method is renamed to `reset_state`."""
        self.mctp_conf_matrix.assign(
            tf.zeros((self.num_classes, self.num_classes), dtype=tf.dtypes.int32)
        )
```

This way you basically get the full matrix at the end of, for example, an evaluation run. If this is something that could help people, I will gladly make a PR. One thing to note is that this class accepts only sparse integer class labels for `y_true` (it reshapes them directly instead of taking an argmax). What are your thoughts?
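As a follow-up on why such a confusion matrix is useful here: per-class precision and recall, and the macro/weighted/micro averages this issue asks for, can all be derived from its rows and columns. A NumPy sketch with made-up counts, using the `tf.math.confusion_matrix` layout (rows = true labels, columns = predictions):

```python
import numpy as np

# Made-up accumulated confusion matrix: rows = true, columns = predicted.
cm = np.array([[5, 1, 0],
               [2, 6, 2],
               [0, 1, 7]])

tp = np.diag(cm).astype(float)           # correct predictions per class
support = cm.sum(axis=1)                 # true sample count per class
precision = tp / cm.sum(axis=0)          # column sums = predicted counts per class
recall = tp / support

macro_recall = recall.mean()                                      # unweighted class mean
weighted_precision = (precision * support).sum() / support.sum()  # support-weighted mean
micro = tp.sum() / cm.sum()              # micro precision == micro recall == accuracy
print(precision, recall, macro_recall, weighted_precision, micro)
```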
TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision. Please consider sending feature requests / contributions to other repositories in the TF community with a similar charter to TFA:
Describe the feature and the current behavior/state.
Please add multi-class precision and recall metrics, much like those in sklearn.metrics. Currently, tf.metrics.Precision and tf.metrics.Recall only support binary labels. sklearn.metrics supports averages of type binary, micro (global average), macro (unweighted average of the metric per label), weighted (macro, but weighted by class support), and samples.

Relevant information
Are you willing to contribute it (yes/no):
If someone can guide me, I am willing to give it a try.
Are you willing to maintain it going forward? (yes/no):
If I implement, then yes.
Is there a relevant academic paper? (if so, where):
The definitions are standard, or see sklearn implementation.
Is there already an implementation in another framework? (if so, where):
Please see sklearn/metrics/_classification.py.
Was it part of tf.contrib? (if so, where):
I couldn't find one.
Which API type would this fall under (layer, metric, optimizer, etc.)
metric.
Who will benefit with this feature?
People who are performing multi-class classification and are looking for precision and recall metrics.
Any other info.
There is already an implementation of F1 score. Perhaps these two metrics can piggyback on that.
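On the piggyback idea: an F1 implementation already has to track per-class true positives, false positives, and false negatives, and precision and recall come from exactly the same counts. A hypothetical sketch of the shared bookkeeping (`prf_from_counts` is an illustrative name, not an addons API):

```python
import numpy as np

def prf_from_counts(tp, fp, fn):
    """Per-class precision, recall and F1 from shared per-class counts."""
    tp, fp, fn = (np.asarray(a, dtype=float) for a in (tp, fp, fn))
    precision = tp / (tp + fp)           # of everything predicted as class k
    recall = tp / (tp + fn)              # of everything truly in class k
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
    return precision, recall, f1

# Made-up counts for three classes.
p, r, f1 = prf_from_counts(tp=[5, 6, 7], fp=[2, 2, 2], fn=[1, 4, 1])
```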
Also, these metrics need to mesh with the binary metrics provided by tf.