From 47a2a786f33c6c2a95ef6f7f260c75a7b537dfcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 8 May 2023 15:51:05 +0200 Subject: [PATCH 1/8] Numpy Macenko augmentor added --- apps/example_aug.py | 42 ++++++++ example.py => apps/example_norm.py | 6 +- torchstain/__init__.py | 2 +- torchstain/base/__init__.py | 2 +- torchstain/base/augmentors/__init__.py | 2 + torchstain/base/augmentors/he_augmentor.py | 6 ++ torchstain/base/augmentors/macenko.py | 10 ++ torchstain/numpy/__init__.py | 2 +- torchstain/numpy/augmentors/__init__.py | 1 + torchstain/numpy/augmentors/macenko.py | 111 +++++++++++++++++++++ torchstain/numpy/normalizers/macenko.py | 5 +- 11 files changed, 181 insertions(+), 8 deletions(-) create mode 100644 apps/example_aug.py rename example.py => apps/example_norm.py (87%) create mode 100644 torchstain/base/augmentors/__init__.py create mode 100644 torchstain/base/augmentors/he_augmentor.py create mode 100644 torchstain/base/augmentors/macenko.py create mode 100644 torchstain/numpy/augmentors/__init__.py create mode 100644 torchstain/numpy/augmentors/macenko.py diff --git a/apps/example_aug.py b/apps/example_aug.py new file mode 100644 index 0000000..8c67515 --- /dev/null +++ b/apps/example_aug.py @@ -0,0 +1,42 @@ +import cv2 +import matplotlib.pyplot as plt +import torchstain +import torch +from torchvision import transforms +import time +import os + + +size = 1024 +dir_path = os.path.dirname(os.path.abspath(__file__)) +target = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/target.png"), cv2.COLOR_BGR2RGB), (size, size)) +to_transform = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/source.png"), cv2.COLOR_BGR2RGB), (size, size)) + +augmentor = torchstain.augmentors.MacenkoAugmentor(backend='numpy') +augmentor.fit(to_transform) + +#T = transforms.Compose([ +# transforms.ToTensor(), +# transforms.Lambda(lambda x: x*255) +#]) + +#t_to_transform = T(to_transform) + +plt.figure() +plt.suptitle('numpy augmentor') +plt.subplot(4, 4, 1) +plt.title('Original') +plt.axis('off') +plt.imshow(to_transform) + +for i in range(16): + # generate augmented sample + result = augmentor.augment() + + plt.subplot(4, 4, i + 1) + if i == 1: + plt.title('Augmented ->') + plt.axis('off') + plt.imshow(result) + +plt.show() diff --git a/example.py b/apps/example_norm.py similarity index 87% rename from example.py rename to apps/example_norm.py index 98ba5b9..893ad81 100644 --- a/example.py +++ b/apps/example_norm.py @@ -4,11 +4,13 @@ import torch from torchvision import transforms import time +import os size = 1024 -target = cv2.resize(cv2.cvtColor(cv2.imread("./data/target.png"), cv2.COLOR_BGR2RGB), (size, size)) -to_transform = cv2.resize(cv2.cvtColor(cv2.imread("./data/source.png"), cv2.COLOR_BGR2RGB), (size, size)) +dir_path = os.path.dirname(os.path.abspath(__file__)) +target = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/target.png"), cv2.COLOR_BGR2RGB), (size, size)) +to_transform = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/source.png"), cv2.COLOR_BGR2RGB), (size, size)) normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') normalizer.fit(target) diff --git a/torchstain/__init__.py b/torchstain/__init__.py index 4b11e31..d8a6bc8 100644 --- a/torchstain/__init__.py +++ b/torchstain/__init__.py @@ -1,3 +1,3 @@ __version__ = '1.3.0' -from torchstain.base import normalizers \ No newline at end of file +from torchstain.base import augmentors, normalizers \ No newline at end of file diff --git a/torchstain/base/__init__.py b/torchstain/base/__init__.py index d66541f..59f3b07 100644 --- a/torchstain/base/__init__.py +++ b/torchstain/base/__init__.py @@ -1 +1 @@ -from torchstain.base import normalizers \ No newline at end of file +from torchstain.base import augmentors, normalizers \ No newline at end of file diff --git a/torchstain/base/augmentors/__init__.py b/torchstain/base/augmentors/__init__.py new file mode 100644 index 0000000..2eb3a59 --- /dev/null +++ b/torchstain/base/augmentors/__init__.py @@ -0,0 +1,2 @@ +from .he_augmentor import HEAugmentor +from .macenko import MacenkoAugmentor diff --git a/torchstain/base/augmentors/he_augmentor.py b/torchstain/base/augmentors/he_augmentor.py new file mode 100644 index 0000000..35c627c --- /dev/null +++ b/torchstain/base/augmentors/he_augmentor.py @@ -0,0 +1,6 @@ +class HEAugmentor: + def fit(self, I): + pass + + def augment(self): + raise Exception('Abstract method') diff --git a/torchstain/base/augmentors/macenko.py b/torchstain/base/augmentors/macenko.py new file mode 100644 index 0000000..67c6e73 --- /dev/null +++ b/torchstain/base/augmentors/macenko.py @@ -0,0 +1,10 @@ +def MacenkoAugmentor(backend='torch'): + if backend == 'numpy': + from torchstain.numpy.augmentors import NumpyMacenkoAugmentor + return NumpyMacenkoAugmentor() + elif backend == "torch": + raise NotImplementedError() + elif backend == "tensorflow": + raise NotImplementedError() + else: + raise Exception(f'Unknown backend {backend}') diff --git a/torchstain/numpy/__init__.py b/torchstain/numpy/__init__.py index 113e6f1..b8aaf02 100644 --- a/torchstain/numpy/__init__.py +++ b/torchstain/numpy/__init__.py @@ -1 +1 @@ -from torchstain.numpy import normalizers, utils +from torchstain.numpy import augmentors, normalizers, utils diff --git a/torchstain/numpy/augmentors/__init__.py b/torchstain/numpy/augmentors/__init__.py new file mode 100644 index 0000000..7ed0af2 --- /dev/null +++ b/torchstain/numpy/augmentors/__init__.py @@ -0,0 +1 @@ +from .macenko import NumpyMacenkoAugmentor \ No newline at end of file diff --git a/torchstain/numpy/augmentors/macenko.py b/torchstain/numpy/augmentors/macenko.py new file mode 100644 index 0000000..382ab03 --- /dev/null +++ b/torchstain/numpy/augmentors/macenko.py @@ -0,0 +1,111 @@ +import numpy as np +from torchstain.base.augmentors import HEAugmentor + +""" +Source code adapted from: https://github.com/schaugf/HEnorm_python +Original implementation: https://github.com/mitkovetta/staining-normalization +""" +class NumpyMacenkoAugmentor(HEAugmentor): + def __init__(self, sigma1=0.2, sigma2=0.2): + super().__init__() + + self.sigma1 = sigma1 + self.sigma2 = sigma2 + + self.I = None + + self.HERef = np.array([[0.5626, 0.2159], + [0.7201, 0.8012], + [0.4062, 0.5581]]) + self.maxCRef = np.array([1.9705, 1.0308]) + + def __convert_rgb2od(self, I, Io=240, beta=0.15): + # calculate optical density + OD = -np.log((I.astype(float)+1)/Io) + + # remove transparent pixels + ODhat = OD[~np.any(OD < beta, axis=1)] + + return OD, ODhat + + def __find_HE(self, ODhat, eigvecs, alpha): + #project on the plane spanned by the eigenvectors corresponding to the two + # largest eigenvalues + That = ODhat.dot(eigvecs[:,1:3]) + + phi = np.arctan2(That[:,1],That[:,0]) + + minPhi = np.percentile(phi, alpha) + maxPhi = np.percentile(phi, 100-alpha) + + vMin = eigvecs[:, 1:3].dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T) + vMax = eigvecs[:, 1:3].dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T) + + # a heuristic to make the vector corresponding to hematoxylin first and the + # one corresponding to eosin second + if vMin[0] > vMax[0]: + HE = np.array((vMin[:,0], vMax[:,0])).T + else: + HE = np.array((vMax[:,0], vMin[:,0])).T + + return HE + + def __find_concentration(self, OD, HE): + # rows correspond to channels (RGB), columns to OD values + Y = np.reshape(OD, (-1, 3)).T + + # determine concentrations of the individual stains + C = np.linalg.lstsq(HE, Y, rcond=None)[0] + + return C + + def __compute_matrices(self, I, Io, alpha, beta): + I = I.reshape((-1, 3)) + + OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) + + # compute eigenvectors + _, eigvecs = np.linalg.eigh(np.cov(ODhat.T)) + + HE = self.__find_HE(ODhat, eigvecs, alpha) + + C = self.__find_concentration(OD, HE) + + # normalize stain concentrations + maxC = np.array([np.percentile(C[0,:], 99), np.percentile(C[1,:], 99)]) + + return HE, C, maxC + + def fit(self, I, Io=240, alpha=1, beta=0.15): + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + # keep these as we will use them for augmentation + self.I = I + self.HERef = HE + self.CRef = C + self.maxCRef = maxC + + def augment(self, Io=240, alpha=1, beta=0.15): + I = self.I + h, w, c = I.shape + I = I.reshape((-1, 3)) + + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + maxC = np.divide(maxC, self.maxCRef) + C2 = np.divide(C, maxC[:, np.newaxis]) + + # introduce noise to the concentrations + for i in range(C2.shape[0]): + alpha = np.random.uniform(1 - self.sigma1, 1 + self.sigma1) + beta = np.random.uniform(-self.sigma2, self.sigma2) + + C2[i, :] *= alpha + C2[i, :] += beta + + # recreate the image using reference mixing matrix + Iaug = np.multiply(Io, np.exp(-self.HERef.dot(C2))) + Iaug[Iaug > 255] = 255 + Iaug = np.reshape(Iaug.T, (h, w, c)).astype(np.uint8) + + return Iaug diff --git a/torchstain/numpy/normalizers/macenko.py b/torchstain/numpy/normalizers/macenko.py index faada97..6a05ae3 100644 --- a/torchstain/numpy/normalizers/macenko.py +++ b/torchstain/numpy/normalizers/macenko.py @@ -10,8 +10,8 @@ def __init__(self): super().__init__() self.HERef = np.array([[0.5626, 0.2159], - [0.7201, 0.8012], - [0.4062, 0.5581]]) + [0.7201, 0.8012], + [0.4062, 0.5581]]) self.maxCRef = np.array([1.9705, 1.0308]) def __convert_rgb2od(self, I, Io=240, beta=0.15): @@ -109,7 +109,6 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True): Inorm[Inorm > 255] = 255 Inorm = np.reshape(Inorm.T, (h, w, c)).astype(np.uint8) - H, E = None, None if stains: From 4a3f08ba33ea793caddac42d6d69f326f1f27564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 8 May 2023 16:21:30 +0200 Subject: [PATCH 2/8] tf macenko aug added --- apps/example_aug.py | 39 ++++++++-- torchstain/base/augmentors/macenko.py | 3 +- torchstain/numpy/augmentors/macenko.py | 9 +-- torchstain/tf/__init__.py | 2 +- torchstain/tf/augmentors/__init__.py | 1 + torchstain/tf/augmentors/macenko.py | 102 +++++++++++++++++++++++++ 6 files changed, 143 insertions(+), 13 deletions(-) create mode 100644 torchstain/tf/augmentors/__init__.py create mode 100644 torchstain/tf/augmentors/macenko.py diff --git a/apps/example_aug.py b/apps/example_aug.py index 8c67515..6d85aac 100644 --- a/apps/example_aug.py +++ b/apps/example_aug.py @@ -12,15 +12,22 @@ target = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/target.png"), cv2.COLOR_BGR2RGB), (size, size)) to_transform = cv2.resize(cv2.cvtColor(cv2.imread(dir_path + "/../data/source.png"), cv2.COLOR_BGR2RGB), (size, size)) +T = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x*255) +]) + +t_to_transform = T(to_transform) + +# setup augmentors for the different backends augmentor = torchstain.augmentors.MacenkoAugmentor(backend='numpy') augmentor.fit(to_transform) -#T = transforms.Compose([ -# transforms.ToTensor(), -# transforms.Lambda(lambda x: x*255) -#]) +tf_augmentor = torchstain.augmentors.MacenkoAugmentor(backend='tensorflow') +tf_augmentor.fit(t_to_transform) + -#t_to_transform = T(to_transform) +print("NUMPY" + "-"*20) plt.figure() plt.suptitle('numpy augmentor') @@ -40,3 +47,25 @@ plt.imshow(result) plt.show() + + +print("TensorFlow (TF)" + "-"*20) + +plt.figure() +plt.suptitle('tf augmentor') +plt.subplot(4, 4, 1) +plt.title('Original') +plt.axis('off') +plt.imshow(to_transform) + +for i in range(16): + # generate augmented sample + result = tf_augmentor.augment() + + plt.subplot(4, 4, i + 1) + if i == 1: + plt.title('Augmented ->') + plt.axis('off') + plt.imshow(result) + +plt.show() diff --git a/torchstain/base/augmentors/macenko.py b/torchstain/base/augmentors/macenko.py index 67c6e73..ceb4099 100644 --- a/torchstain/base/augmentors/macenko.py +++ b/torchstain/base/augmentors/macenko.py @@ -5,6 +5,7 @@ def MacenkoAugmentor(backend='torch'): elif backend == "torch": raise NotImplementedError() elif backend == "tensorflow": - raise NotImplementedError() + from torchstain.tf.augmentors import TensorFlowMacenkoAugmentor + return TensorFlowMacenkoAugmentor() else: raise Exception(f'Unknown backend {backend}') diff --git a/torchstain/numpy/augmentors/macenko.py b/torchstain/numpy/augmentors/macenko.py index 382ab03..7de8ead 100644 --- a/torchstain/numpy/augmentors/macenko.py +++ b/torchstain/numpy/augmentors/macenko.py @@ -21,7 +21,7 @@ def __init__(self, sigma1=0.2, sigma2=0.2): def __convert_rgb2od(self, I, Io=240, beta=0.15): # calculate optical density - OD = -np.log((I.astype(float)+1)/Io) + OD = -np.log((I.astype(float) + 1) / Io) # remove transparent pixels ODhat = OD[~np.any(OD < beta, axis=1)] @@ -97,11 +97,8 @@ def augment(self, Io=240, alpha=1, beta=0.15): # introduce noise to the concentrations for i in range(C2.shape[0]): - alpha = np.random.uniform(1 - self.sigma1, 1 + self.sigma1) - beta = np.random.uniform(-self.sigma2, self.sigma2) - - C2[i, :] *= alpha - C2[i, :] += beta + C2[i, :] *= np.random.uniform(1 - self.sigma1, 1 + self.sigma1) # multiplicative + C2[i, :] += np.random.uniform(-self.sigma2, self.sigma2) # additative # recreate the image using reference mixing matrix Iaug = np.multiply(Io, np.exp(-self.HERef.dot(C2))) diff --git a/torchstain/tf/__init__.py b/torchstain/tf/__init__.py index edaf586..6c4ca79 100644 --- a/torchstain/tf/__init__.py +++ b/torchstain/tf/__init__.py @@ -1 +1 @@ -from torchstain.tf import normalizers, utils +from torchstain.tf import augmentors, normalizers, utils diff --git a/torchstain/tf/augmentors/__init__.py b/torchstain/tf/augmentors/__init__.py new file mode 100644 index 0000000..f51d11a --- /dev/null +++ b/torchstain/tf/augmentors/__init__.py @@ -0,0 +1 @@ +from .macenko import TensorFlowMacenkoAugmentor diff --git a/torchstain/tf/augmentors/macenko.py b/torchstain/tf/augmentors/macenko.py new file mode 100644 index 0000000..25160c4 --- /dev/null +++ b/torchstain/tf/augmentors/macenko.py @@ -0,0 +1,102 @@ +import tensorflow as tf +from torchstain.base.augmentors.he_augmentor import HEAugmentor +from torchstain.tf.utils import cov, percentile, solveLS +import numpy as np +import tensorflow.keras.backend as K + + +""" +Source code ported from: https://github.com/schaugf/HEnorm_python +Original implementation: https://github.com/mitkovetta/staining-normalization +""" +class TensorFlowMacenkoAugmentor(HEAugmentor): + def __init__(self, sigma1=0.2, sigma2=0.2): + super().__init__() + + self.sigma1 = sigma1 + self.sigma2 = sigma2 + + self.I = None + + self.HERef = tf.constant([[0.5626, 0.2159], + [0.7201, 0.8012], + [0.4062, 0.5581]]) + self.maxCRef = tf.constant([1.9705, 1.0308]) + + def __convert_rgb2od(self, I, Io, beta): + I = tf.transpose(I, [1, 2, 0]) + + # calculate optical density + OD = -tf.math.log((tf.cast(tf.reshape(I, [tf.math.reduce_prod(I.shape[:-1]), I.shape[-1]]), tf.float32) + 1) / Io) + + # remove transparent pixels + ODhat = OD[~tf.math.reduce_any(OD < beta, axis=1)] + + return OD, ODhat + + def __find_HE(self, ODhat, eigvecs, alpha): + # project on the plane spanned by the eigenvectors corresponding to the two + # largest eigenvalues + That = tf.linalg.matmul(ODhat, eigvecs) + phi = tf.math.atan2(That[:, 1], That[:, 0]) + + minPhi = percentile(phi, alpha) + maxPhi = percentile(phi, 100 - alpha) + + vMin = tf.matmul(eigvecs, tf.expand_dims(tf.stack((tf.math.cos(minPhi), tf.math.sin(minPhi))), axis=-1)) + vMax = tf.matmul(eigvecs, tf.expand_dims(tf.stack((tf.math.cos(maxPhi), tf.math.sin(maxPhi))), axis=-1)) + + # a heuristic to make the vector corresponding to hematoxylin first and the + # one corresponding to eosin second + HE = tf.where(vMin[0] > vMax[0], tf.concat((vMin, vMax), axis=1), tf.concat((vMax, vMin), axis=1)) + + return HE + + def __find_concentration(self, OD, HE): + # rows correspond to channels (RGB), columns to OD values + Y = tf.transpose(OD) + + # determine concentrations of the individual stains + return solveLS(HE, Y)[:2] + + def __compute_matrices(self, I, Io, alpha, beta): + OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) + + # compute eigenvectors + _, eigvecs = tf.linalg.eigh(cov(tf.transpose(ODhat))) + eigvecs = eigvecs[:, 1:3] + + HE = self.__find_HE(ODhat, eigvecs, alpha) + + C = self.__find_concentration(OD, HE) + maxC = tf.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)]) + + return HE, C, maxC + + def fit(self, I, Io=240, alpha=1, beta=0.15): + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + # keep these as we will use them for augmentation + self.I = I + self.HERef = HE + self.CRef = C + self.maxCRef = maxC + + def augment(self, Io=240, alpha=1, beta=0.15): + I = self.I + c, h, w = I.shape + + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + maxC = maxC / self.maxCRef + C2 = C / tf.expand_dims(maxC, axis=-1) + + # introduce noise to the concentrations (applied along axis=0) + C2 = (C2 * tf.random.uniform((2, 1), 1 - self.sigma1, 1 + self.sigma1)) + tf.random.uniform((2, 1), -self.sigma2, self.sigma2) + + # recreate the image using reference mixing matrix + Iaug = Io * tf.math.exp(-tf.linalg.matmul(self.HERef, C2)) + Iaug = tf.clip_by_value(Iaug, 0, 255) + Iaug = tf.cast(tf.reshape(tf.transpose(Iaug), shape=(h, w, c)), tf.int32) + + return Iaug From 14e7b7306bcb5d2c6712f7a14e5ff039e13a11da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 8 May 2023 16:54:18 +0200 Subject: [PATCH 3/8] torch macenko aug works - all backends supported --- apps/example_aug.py | 25 ++++++ torchstain/base/augmentors/macenko.py | 3 +- torchstain/torch/__init__.py | 2 +- torchstain/torch/augmentors/__init__.py | 1 + torchstain/torch/augmentors/macenko.py | 110 ++++++++++++++++++++++++ 5 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 torchstain/torch/augmentors/__init__.py create mode 100644 torchstain/torch/augmentors/macenko.py diff --git a/apps/example_aug.py b/apps/example_aug.py index 6d85aac..46c952d 100644 --- a/apps/example_aug.py +++ b/apps/example_aug.py @@ -26,6 +26,9 @@ tf_augmentor = torchstain.augmentors.MacenkoAugmentor(backend='tensorflow') tf_augmentor.fit(t_to_transform) +torch_augmentor = torchstain.augmentors.MacenkoAugmentor(backend='torch') +torch_augmentor.fit(t_to_transform) + print("NUMPY" + "-"*20) @@ -69,3 +72,25 @@ plt.imshow(result) plt.show() + + +print("Torch" + "-"*20) + +plt.figure() +plt.suptitle('torch augmentor') +plt.subplot(4, 4, 1) +plt.title('Original') +plt.axis('off') +plt.imshow(to_transform) + +for i in range(16): + # generate augmented sample + result = torch_augmentor.augment() + + plt.subplot(4, 4, i + 1) + if i == 1: + plt.title('Augmented ->') + plt.axis('off') + plt.imshow(result) + +plt.show() diff --git a/torchstain/base/augmentors/macenko.py b/torchstain/base/augmentors/macenko.py index ceb4099..6345ae5 100644 --- a/torchstain/base/augmentors/macenko.py +++ b/torchstain/base/augmentors/macenko.py @@ -3,7 +3,8 @@ def MacenkoAugmentor(backend='torch'): from torchstain.numpy.augmentors import NumpyMacenkoAugmentor return NumpyMacenkoAugmentor() elif backend == "torch": - raise NotImplementedError() + from torchstain.torch.augmentors import TorchMacenkoAugmentor + return TorchMacenkoAugmentor() elif backend == "tensorflow": from torchstain.tf.augmentors import TensorFlowMacenkoAugmentor return TensorFlowMacenkoAugmentor() diff --git a/torchstain/torch/__init__.py b/torchstain/torch/__init__.py index fef6329..cbbd0e9 100644 --- a/torchstain/torch/__init__.py +++ b/torchstain/torch/__init__.py @@ -1 +1 @@ -from torchstain.torch import normalizers, utils \ No newline at end of file +from torchstain.torch import augmentors, normalizers, utils \ No newline at end of file diff --git a/torchstain/torch/augmentors/__init__.py b/torchstain/torch/augmentors/__init__.py new file mode 100644 index 0000000..7c9bf4d --- /dev/null +++ b/torchstain/torch/augmentors/__init__.py @@ -0,0 +1 @@ +from .macenko import TorchMacenkoAugmentor diff --git a/torchstain/torch/augmentors/macenko.py b/torchstain/torch/augmentors/macenko.py new file mode 100644 index 0000000..0def8e4 --- /dev/null +++ b/torchstain/torch/augmentors/macenko.py @@ -0,0 +1,110 @@ +import torch +from torchstain.base.augmentors.he_augmentor import HEAugmentor +from torchstain.torch.utils import cov, percentile + +""" +Source code ported from: https://github.com/schaugf/HEnorm_python +Original implementation: https://github.com/mitkovetta/staining-normalization +""" +class TorchMacenkoAugmentor(HEAugmentor): + def __init__(self, sigma1=0.2, sigma2=0.2): + super().__init__() + + self.sigma1 = sigma1 + self.sigma2 = sigma2 + + self.I = None + + self.HERef = torch.tensor([[0.5626, 0.2159], + [0.7201, 0.8012], + [0.4062, 0.5581]]) + self.maxCRef = torch.tensor([1.9705, 1.0308]) + + # Avoid using deprecated torch.lstsq (since 1.9.0) + self.updated_lstsq = hasattr(torch.linalg, 'lstsq') + + def __convert_rgb2od(self, I, Io, beta): + I = I.permute(1, 2, 0) + + # calculate optical density + OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1)/Io) + + # remove transparent pixels + ODhat = OD[~torch.any(OD < beta, dim=1)] + + return OD, ODhat + + def __find_HE(self, ODhat, eigvecs, alpha): + # project on the plane spanned by the eigenvectors corresponding to the two + # largest eigenvalues + That = torch.matmul(ODhat, eigvecs) + phi = torch.atan2(That[:, 1], That[:, 0]) + + minPhi = percentile(phi, alpha) + maxPhi = percentile(phi, 100 - alpha) + + vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1) + vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1) + + # a heuristic to make the vector corresponding to hematoxylin first and the + # one corresponding to eosin second + HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1)) + + return HE + + def __find_concentration(self, OD, HE): + # rows correspond to channels (RGB), columns to OD values + Y = OD.T + + # determine concentrations of the individual stains + if not self.updated_lstsq: + return torch.lstsq(Y, HE)[0][:2] + + return torch.linalg.lstsq(HE, Y)[0] + + def __compute_matrices(self, I, Io, alpha, beta): + OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) + + # compute eigenvectors + _, eigvecs = torch.linalg.eigh(cov(ODhat.T)) + eigvecs = eigvecs[:, [1, 2]] + + HE = self.__find_HE(ODhat, eigvecs, alpha) + + C = self.__find_concentration(OD, HE) + maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)]) + + return HE, C, maxC + + def fit(self, I, Io=240, alpha=1, beta=0.15): + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + # keep these as we will use them for augmentation + self.I = I + self.HERef = HE + self.CRef = C + self.maxCRef = maxC + + @staticmethod + def random_uniform(shape, low, high): + return (low - high) * torch.rand(*shape) + high + + def augment(self, Io=240, alpha=1, beta=0.15): + I = self.I + c, h, w = I.shape + + HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta) + + maxC = maxC / self.maxCRef + C2 = C / torch.unsqueeze(maxC, axis=-1) + + # introduce noise to the concentrations (applied along axis=0) + C2 = (C2 * self.random_uniform((2, 1), 1 - self.sigma1, 1 + self.sigma1)) + self.random_uniform((2, 1), -self.sigma2, self.sigma2) + + # recreate the image using reference mixing matrix + Inorm = Io * torch.exp(-torch.matmul(self.HERef, C2)) + Inorm[Inorm > 255] = 255 + Inorm = Inorm.T.reshape(h, w, c).int() + + return Inorm + \ No newline at end of file From b13ad2a85a4275e00927fbc1fe885036fa5493b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 1 Apr 2024 10:32:03 +0200 Subject: [PATCH 4/8] Added missing transpose to aug numpy macenko --- torchstain/numpy/augmentors/macenko.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchstain/numpy/augmentors/macenko.py b/torchstain/numpy/augmentors/macenko.py index 7de8ead..da639b9 100644 --- a/torchstain/numpy/augmentors/macenko.py +++ b/torchstain/numpy/augmentors/macenko.py @@ -60,7 +60,7 @@ def __find_concentration(self, OD, HE): return C def __compute_matrices(self, I, Io, alpha, beta): - I = I.reshape((-1, 3)) + I = I.reshape((-1, 3)).T OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) From c0461441612d21c2c056d3a9fbc98453eb763fcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 1 Apr 2024 10:56:46 +0200 Subject: [PATCH 5/8] Revert change --- torchstain/numpy/augmentors/macenko.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchstain/numpy/augmentors/macenko.py b/torchstain/numpy/augmentors/macenko.py index da639b9..7de8ead 100644 --- a/torchstain/numpy/augmentors/macenko.py +++ b/torchstain/numpy/augmentors/macenko.py @@ -60,7 +60,7 @@ def __find_concentration(self, OD, HE): return C def __compute_matrices(self, I, Io, alpha, beta): - I = I.reshape((-1, 3)).T + I = I.reshape((-1, 3)) OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta) From 694f5d0b484b0660a575c66e7e9a3790d1eaff5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 1 Apr 2024 11:04:25 +0200 Subject: [PATCH 6/8] Allow for sigma1 and sigma2 to be set for the Macenko augmentor --- torchstain/base/augmentors/macenko.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchstain/base/augmentors/macenko.py b/torchstain/base/augmentors/macenko.py index 6345ae5..449f248 100644 --- a/torchstain/base/augmentors/macenko.py +++ b/torchstain/base/augmentors/macenko.py @@ -1,12 +1,12 @@ -def MacenkoAugmentor(backend='torch'): +def MacenkoAugmentor(backend='torch', sigma1=0.2, sigma2=0.2): if backend == 'numpy': from torchstain.numpy.augmentors import NumpyMacenkoAugmentor - return NumpyMacenkoAugmentor() + return NumpyMacenkoAugmentor(sigma1=sigma1, sigma2=sigma2) elif backend == "torch": from torchstain.torch.augmentors import TorchMacenkoAugmentor - return TorchMacenkoAugmentor() + return TorchMacenkoAugmentor(sigma1=sigma1, sigma2=sigma2) elif backend == "tensorflow": from torchstain.tf.augmentors import TensorFlowMacenkoAugmentor - return TensorFlowMacenkoAugmentor() + return TensorFlowMacenkoAugmentor(sigma1=sigma1, sigma2=sigma2) else: raise Exception(f'Unknown backend {backend}') From b3100c667d2830ff34137524a78eebb6f9bbdeea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 1 Apr 2024 11:27:06 +0200 Subject: [PATCH 7/8] Tried fixing download artifact step in test CIs --- .github/workflows/tests_full.yml | 4 ++-- .github/workflows/tests_quick.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index 78f0c45..193c356 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -51,7 +51,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Download artifact - uses: actions/download-artifact@master + uses: actions/download-artifact@v3 with: name: "Python wheel" @@ -90,7 +90,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Download artifact - uses: actions/download-artifact@master + uses: actions/download-artifact@v3 with: name: "Python wheel" diff --git a/.github/workflows/tests_quick.yml b/.github/workflows/tests_quick.yml index 188fb02..9441a35 100644 --- a/.github/workflows/tests_quick.yml +++ b/.github/workflows/tests_quick.yml @@ -43,7 +43,7 @@ jobs: python-version: 3.8 - name: Download artifact - uses: actions/download-artifact@master + uses: actions/download-artifact@v3 with: name: "Python wheel" @@ -70,7 +70,7 @@ jobs: python-version: 3.8 - name: Download artifact - uses: actions/download-artifact@master + uses: actions/download-artifact@v3 with: name: "Python wheel" From bf8e77d7c740ab8940e91d50dac9e0ab4ab57afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Mon, 1 Apr 2024 11:31:04 +0200 Subject: [PATCH 8/8] Run tests on macos-12 --- .github/workflows/tests_full.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests_full.yml b/.github/workflows/tests_full.yml index 193c356..31f4564 100644 --- a/.github/workflows/tests_full.yml +++ b/.github/workflows/tests_full.yml @@ -39,7 +39,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ windows-2019, ubuntu-20.04, macos-11 ] + os: [ windows-2019, ubuntu-20.04, macos-12 ] python-version: [ 3.7, 3.8, 3.9 ] tf-version: [2.7.0, 2.8.0, 2.9.0] @@ -71,7 +71,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ windows-2019, ubuntu-20.04, macos-11 ] + os: [ windows-2019, ubuntu-20.04, macos-12 ] python-version: [ 3.6, 3.7, 3.8, 3.9 ] pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] exclude: