From f9f7cfa5bd5a84bd6ee6bb4e650de624bf918409 Mon Sep 17 00:00:00 2001 From: EIFY Date: Thu, 28 Nov 2024 14:21:54 -0800 Subject: [PATCH] fix contrast() transform See https://github.com/google-research/big_vision/issues/109 Fix suggested by @yeqingli in https://github.com/tensorflow/models/pull/11219#pullrequestreview-2355525720 --- .../imagenet_resnet/imagenet_jax/randaugment.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..95a96db18 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -128,19 +128,7 @@ def color(image, factor): def contrast(image, factor): """Equivalent of PIL Contrast.""" - degenerate = tf.image.rgb_to_grayscale(image) - # Cast before calling tf.histogram. - degenerate = tf.cast(degenerate, tf.int32) - - # Compute the grayscale histogram, then compute the mean pixel value, - # and create a constant image size of that value. Use that as the - # blending degenerate target of the original image. - hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) - mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 - degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean - degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) - degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) - return blend(degenerate, image, factor) + return tf.image.adjust_contrast(image, factor) def brightness(image, factor):